Class: Transformers::Mpnet::MPNetClassificationHead
- Inherits: Torch::NN::Module
- Hierarchy: Object > Torch::NN::Module > Transformers::Mpnet::MPNetClassificationHead
- Defined in: lib/transformers/models/mpnet/modeling_mpnet.rb
Instance Method Summary
- #forward(features, **kwargs) ⇒ Object
- #initialize(config) ⇒ MPNetClassificationHead (constructor): A new instance of MPNetClassificationHead.
Constructor Details
#initialize(config) ⇒ MPNetClassificationHead
Returns a new instance of MPNetClassificationHead.
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 708

def initialize(config)
  super()
  @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  @out_proj = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
end
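The constructor reads three fields from config: hidden_size, hidden_dropout_prob, and num_labels. As a rough sketch of what that implies for the size of the head, the parameter count can be worked out in plain Ruby without loading torch-rb. HeadConfig, linear_param_count, and head_param_count are hypothetical names introduced for illustration, the values 768, 0.1, and 2 are example settings rather than anything stated in the source, and the arithmetic assumes each Torch::NN::Linear layer keeps a bias term.

```ruby
# Hypothetical stand-in for the config object; only the three fields
# read by MPNetClassificationHead#initialize are modeled here.
HeadConfig = Struct.new(:hidden_size, :hidden_dropout_prob, :num_labels, keyword_init: true)

# Parameter count of one linear layer, assuming it has a bias term.
def linear_param_count(in_features, out_features)
  in_features * out_features + out_features
end

# Sums the parameters of the two linear layers; dropout has none.
def head_param_count(config)
  dense = linear_param_count(config.hidden_size, config.hidden_size)
  out_proj = linear_param_count(config.hidden_size, config.num_labels)
  dense + out_proj
end

# Illustrative values only.
config = HeadConfig.new(hidden_size: 768, hidden_dropout_prob: 0.1, num_labels: 2)
puts head_param_count(config)
```

Under these assumptions, nearly all of the head's parameters sit in the hidden_size by hidden_size dense layer; the output projection grows only with num_labels.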
Instance Method Details
#forward(features, **kwargs) ⇒ Object
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 715

def forward(features, **kwargs)
  x = features[0.., 0, 0..]
  x = @dropout.(x)
  x = @dense.(x)
  x = Torch.tanh(x)
  x = @dropout.(x)
  x = @out_proj.(x)
  x
end
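The slice features[0.., 0, 0..] keeps index 0 along the second axis, which suggests features is laid out as (batch, sequence, hidden) and that only the first token of each sequence feeds the classifier. The following plain-Ruby sketch walks through the same steps with nested arrays standing in for tensors. It is not the library's implementation: linear and classification_head are hypothetical helpers, all weights are made-up values, and dropout is treated as the identity, as it would be in evaluation mode.

```ruby
# Plain-Ruby walk-through of the forward pass in evaluation mode,
# where dropout is assumed to act as the identity.

# y = W x + b for a single vector x; weight is out_features x in_features.
def linear(x, weight, bias)
  weight.each_with_index.map do |row, i|
    row.zip(x).sum { |w, v| w * v } + bias[i]
  end
end

def classification_head(features, dense_w, dense_b, out_w, out_b)
  # Counterpart of features[0.., 0, 0..]: first token of every sequence.
  first_tokens = features.map { |sequence| sequence[0] }
  first_tokens.map do |x|
    hidden = linear(x, dense_w, dense_b).map { |v| Math.tanh(v) }
    linear(hidden, out_w, out_b)
  end
end

# Batch of 2 sequences, 3 tokens each, hidden size 2 (illustrative data).
features = [
  [[1.0, 0.0], [9.0, 9.0], [9.0, 9.0]],
  [[0.0, 1.0], [9.0, 9.0], [9.0, 9.0]]
]
dense_w = [[1.0, 0.0], [0.0, 1.0]] # identity matrix, for easy checking
dense_b = [0.0, 0.0]
out_w = [[1.0, 1.0], [1.0, -1.0], [0.0, 0.0]] # 3 labels
out_b = [0.0, 0.0, 0.5]

logits = classification_head(features, dense_w, dense_b, out_w, out_b)
p logits.map(&:size)
```

The result has one row per batch item and one column per label. The tokens after the first (the 9.0 entries here) never influence the output.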