Class: Transformers::Mpnet::MPNetClassificationHead

Inherits:
Torch::NN::Module (ancestor chain: Object → Torch::NN::Module → Transformers::Mpnet::MPNetClassificationHead)
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Method Summary

#forward(features, **kwargs) ⇒ Object

Constructor Details

#initialize(config) ⇒ MPNetClassificationHead

Returns a new instance of MPNetClassificationHead.



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', lines 708–713

# Builds the classification head's sub-modules: a square projection over the
# hidden size, a dropout layer (reused twice in #forward), and a final
# projection from the hidden size onto the label count.
#
# @param config [#hidden_size, #hidden_dropout_prob, #num_labels] model
#   configuration supplying the layer widths and the dropout probability
def initialize(config)
  super()
  width = config.hidden_size
  # Assignment order is kept as dense → dropout → out_proj so that module
  # registration order and weight-initialization order are unchanged.
  @dense = Torch::NN::Linear.new(width, width)
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  @out_proj = Torch::NN::Linear.new(width, config.num_labels)
end

Instance Method Details

#forward(features, **kwargs) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', lines 715–723

# Maps encoder output to classification scores.
#
# Takes the slice at index 0 of the second axis of +features+. Presumably
# this is the leading <s> token of a (batch, seq_len, hidden) tensor; TODO:
# confirm with callers. The slice then goes through
# dropout → dense → tanh → dropout → out_proj.
#
# @param features [Torch::Tensor] tensor indexed with three subscripts, so
#   it must have at least three dimensions
# @param kwargs [Hash] accepted for interface compatibility; not used here
# @return [Torch::Tensor] the output of the final +@out_proj+ projection
def forward(features, **kwargs)
  pooled = features[0.., 0, 0..]
  hidden = @dense.call(@dropout.call(pooled))
  activated = Torch.tanh(hidden)
  @out_proj.call(@dropout.call(activated))
end