Class: Transformers::DebertaV2::ContextPooler

Inherits: Torch::NN::Module < Object
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ ContextPooler

Returns a new instance of ContextPooler.



Source lines 18–23:
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 18

# Builds the pooler head.
#
# @param config [Object] model configuration; this method reads
#   `pooler_hidden_size` and `pooler_dropout` from it.
def initialize(config)
  super()
  # Square projection: pooler_hidden_size features in, pooler_hidden_size out.
  width = config.pooler_hidden_size
  @dense = Torch::NN::Linear.new(width, width)
  # Dropout module configured with the pooler-specific rate.
  @dropout = StableDropout.new(config.pooler_dropout)
  # Retained so #forward can look up pooler_hidden_act and #output_dim can
  # read hidden_size.
  @config = config
end

Instance Method Details

#forward(hidden_states) ⇒ Object



Source lines 25–34:
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 25

# Pools a sequence of hidden states into one vector per sequence.
#
# Pooling here means picking the hidden state of the first token. The
# result is then passed through dropout, the dense projection and the
# configured activation, in that order.
#
# @param hidden_states [Torch::Tensor] presumably (batch, seq, hidden);
#   TODO confirm against callers.
# @return [Torch::Tensor] the activated projection of the first token.
def forward(hidden_states)
  # Every entry along the first axis, index 0 along the second axis.
  first_token = hidden_states[0.., 0]
  projected = @dense.call(@dropout.call(first_token))
  ACT2FN[@config.pooler_hidden_act].call(projected)
end

#output_dim ⇒ Object



Source lines 36–38:
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 36

# Width of the pooled output, taken from the configuration's hidden_size.
#
# NOTE(review): @dense outputs pooler_hidden_size features. This method
# reports hidden_size, which is only equal to that width if the config keeps
# the two the same. Confirm this before relying on output_dim for layer sizing.
def output_dim
  @config.then(&:hidden_size)
end