Class: Transformers::XlmRoberta::XLMRobertaLMHead

Inherits:
Torch::NN::Module < Object
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Method Summary

  #initialize(config) ⇒ XLMRobertaLMHead (constructor)
  #_tie_weights ⇒ Object
  #forward(features, **kwargs) ⇒ Object

Constructor Details

#initialize(config) ⇒ XLMRobertaLMHead

Returns a new instance of XLMRobertaLMHead.



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 920

def initialize(config)
  super()
  @dense = Torch::NN::Linear.new(config.hidden_size, config.hidden_size)
  @layer_norm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)

  @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size)
  @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))
  # Link the decoder bias to this parameter so the two stay in sync
  # (mirrors `self.decoder.bias = self.bias` in the upstream Python model).
  @decoder.instance_variable_set(:@bias, @bias)
end
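
A minimal construction sketch (not part of the gem's documentation): any object that responds to hidden_size, vocab_size, and layer_norm_eps can serve as the config. StubConfig below is a hypothetical stand-in, and the transformers-rb gem (with its torch-rb dependency) is assumed to be loaded already, e.g. via Bundler.

# Hypothetical minimal config; a real Transformers::XlmRoberta config object
# would normally come from a pretrained model.
StubConfig = Struct.new(:hidden_size, :vocab_size, :layer_norm_eps)
config = StubConfig.new(16, 32, 1e-5)

lm_head = Transformers::XlmRoberta::XLMRobertaLMHead.new(config)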

Instance Method Details

#_tie_weights ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 941

def _tie_weights
  # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
  # For accelerate compatibility and to not break backward compatibility
  if @decoder.bias.device.type == "meta"
    # Re-point the decoder at this head's bias parameter
    # (mirrors `self.decoder.bias = self.bias` in the upstream Python model).
    @decoder.instance_variable_set(:@bias, @bias)
  else
    @bias = @decoder.bias
  end
end
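
Continuing the hypothetical sketch above: after _tie_weights runs (the non-"meta" branch on CPU), the head's bias and the decoder's bias refer to the same Torch::NN::Parameter object.

lm_head._tie_weights
head_bias    = lm_head.instance_variable_get(:@bias)
decoder_bias = lm_head.instance_variable_get(:@decoder).bias
p head_bias.equal?(decoder_bias)   # => true, the two names share one parameter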

#forward(features, **kwargs) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 930

def forward(features, **kwargs)
  # transform the hidden states: dense projection, GELU activation, layer norm
  x = @dense.(features)
  x = Activations.gelu(x)
  x = @layer_norm.(x)

  # project back to size of vocabulary with bias
  x = @decoder.(x)

  x
end
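
Continuing the same sketch: forward maps encoder features of shape [batch, seq_len, hidden_size] to per-token vocabulary logits. The shapes below assume the hypothetical StubConfig (hidden_size 16, vocab_size 32).

features = Torch.randn(2, 5, 16)   # [batch, seq_len, hidden_size]
logits = lm_head.(features)        # dense -> GELU -> layer norm -> decoder
p logits.shape                     # => [2, 5, 32], i.e. [batch, seq_len, vocab_size]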