Class: Transformers::Vit::ViTPooler

Inherits:
Torch::NN::Module (which in turn inherits from Object)
Defined in:
lib/transformers/models/vit/modeling_vit.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ ViTPooler

Returns a new instance of ViTPooler.



# File 'lib/transformers/models/vit/modeling_vit.rb', lines 432-436

# Builds the pooling head: a square Linear projection whose input and
# output width are both config.hidden_size, followed by a Tanh activation.
#
# @param config [Object] any object responding to #hidden_size
def initialize(config)
  super()
  width = config.hidden_size
  @dense = Torch::NN::Linear.new(width, width)
  @activation = Torch::NN::Tanh.new
end

Instance Method Details

#forward(hidden_states) ⇒ Object



# File 'lib/transformers/models/vit/modeling_vit.rb', lines 438-445

# Pools the sequence by keeping only the first token's hidden state. That
# state is passed through the dense projection, then the tanh activation.
#
# @param hidden_states tensor indexed as [dim0, dim1, ...]
#   NOTE(review): presumably (batch, seq_len, hidden_size) — confirm with caller.
# @return the activated projection of the first-token slice
def forward(hidden_states)
  # `0..` keeps every entry along dim 0; `0` selects index 0 along dim 1.
  first_token = hidden_states[0.., 0]
  @activation.(@dense.(first_token))
end