Class: Transformers::XlmRoberta::XLMRobertaClassificationHead

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Method Summary collapse

Constructor Details

#initialize(config) ⇒ XLMRobertaClassificationHead

Returns a new instance of XLMRobertaClassificationHead.



1132
1133
1134
1135
1136
1137
1138
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 1132

# Builds the three layers of the classification head: a square dense layer,
# a dropout layer, and an output projection.
#
# @param config [Object] configuration object; this method reads
#   +hidden_size+, +num_labels+, +classifier_dropout+ and
#   +hidden_dropout_prob+ from it
def initialize(config)
  super()
  @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)

  # An explicit classifier_dropout wins; only a nil value falls back to
  # hidden_dropout_prob (so a configured 0.0 is honoured as-is).
  dropout_rate =
    if config.classifier_dropout.nil?
      config.hidden_dropout_prob
    else
      config.classifier_dropout
    end
  @dropout = Torch::NN::Dropout.new(p: dropout_rate)

  @out_proj = Torch::NN::Linear.new(config.hidden_size, config.num_labels)
end

Instance Method Details

#forward(features, **kwargs) ⇒ Object



1140
1141
1142
1143
1144
1145
1146
1147
1148
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 1140

# Runs the classification head: slice, dropout, dense, tanh, dropout, and
# output projection, in that order.
#
# @param features [Object] indexed with three subscripts, taking index 0 on
#   the second one; presumably a (batch, sequence, hidden) tensor whose first
#   sequence position is the pooled token -- TODO confirm against callers
# @param kwargs [Hash] accepted but not used by this method
# @return [Object] whatever +@out_proj+ returns for the processed input
def forward(features, **kwargs)
  first_position = features[0.., 0, 0..]

  projected = @dense.call(@dropout.call(first_position))
  activated = Torch.tanh(projected)

  # The same dropout layer is applied a second time, after the activation.
  @out_proj.call(@dropout.call(activated))
end