Class: Transformers::Bert::BertLMPredictionHead

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/transformers/models/bert/modeling_bert.rb

Instance Method Summary collapse

Constructor Details

#initialize(config) ⇒ BertLMPredictionHead

Returns a new instance of BertLMPredictionHead.



478
479
480
481
482
483
484
485
486
487
488
489
490
# File 'lib/transformers/models/bert/modeling_bert.rb', line 478

# Builds the masked-LM prediction head: a hidden-state transform followed by
# a bias-free projection onto the vocabulary, plus a standalone per-token bias.
#
# NOTE(review): instance variables are assigned in the same order as before
# (@transform, @decoder, @bias) on purpose, since torch-rb discovers module
# parameters by walking instance variables. Do not reorder.
#
# @param config [Object] model config; must respond to #hidden_size and
#   #vocab_size (read here) and whatever BertPredictionHeadTransform reads
def initialize(config)
  # Explicit empty parens: the parent Module constructor takes no arguments.
  super()

  @transform = BertPredictionHeadTransform.new(config)

  # Projection from hidden size to vocabulary size. The output weights are the
  # same as the input embeddings (tied elsewhere). Only the bias is
  # output-specific, so the layer is built without its own bias.
  in_dim = config.hidden_size
  out_dim = config.vocab_size
  @decoder = Torch::NN::Linear.new(in_dim, out_dim, bias: false)

  # Zero-initialised bias with one entry per vocabulary token.
  @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))

  # Link the head's bias into the decoder so both refer to the same Parameter.
  # Keeping them linked is what lets `resize_token_embeddings` resize the bias
  # correctly.
  @decoder.instance_variable_set(:@bias, @bias)
end

Instance Method Details

#_tie_weights ⇒ Object



492
493
494
# File 'lib/transformers/models/bert/modeling_bert.rb', line 492

# Re-establishes the link between the decoder's bias and this head's bias
# Parameter by writing the head's @bias into the decoder's @bias ivar.
#
# @return [Object] the bias object that was linked (the head's @bias)
def _tie_weights
  shared_bias = @bias
  @decoder.instance_variable_set(:@bias, shared_bias)
end

#forward(hidden_states) ⇒ Object



496
497
498
499
500
# File 'lib/transformers/models/bert/modeling_bert.rb', line 496

# Runs the prediction head: applies the hidden-state transform, then the
# vocabulary decoder, and returns the decoder's output.
#
# @param hidden_states [Object] input accepted by @transform
#   (presumably a Torch::Tensor of encoder outputs; not checked here)
# @return [Object] whatever @decoder returns for the transformed input
def forward(hidden_states)
  transformed = @transform.call(hidden_states)
  @decoder.call(transformed)
end