Class: Transformers::Bert::BertLMPredictionHead
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Bert::BertLMPredictionHead
- Defined in: lib/transformers/models/bert/modeling_bert.rb
Instance Method Summary
- #_tie_weights ⇒ Object
- #forward(hidden_states) ⇒ Object
- #initialize(config) ⇒ BertLMPredictionHead (constructor): A new instance of BertLMPredictionHead.
Constructor Details
#initialize(config) ⇒ BertLMPredictionHead
Returns a new instance of BertLMPredictionHead.
```ruby
# File 'lib/transformers/models/bert/modeling_bert.rb', line 478

def initialize(config)
  super()
  @transform = BertPredictionHeadTransform.new(config)

  # The output weights are the same as the input embeddings, but there is
  # an output-only bias for each token.
  @decoder = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)

  @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))

  # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  @decoder.instance_variable_set(:@bias, @bias)
end
```
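The constructor comment says the output weights are the same as the input embeddings. A quick shape check in plain Ruby (no torch-rb required) shows why that can work: `Torch::NN::Linear.new(in_features, out_features)` stores its weight as `[out_features, in_features]`, so this decoder's weight is `[vocab_size, hidden_size]`, the same layout as a word-embedding table of `vocab_size` rows. This is a sketch under assumptions: the config values (768, 30_522) are BERT-base-style numbers chosen for illustration, the word embeddings are assumed to be an `Embedding(vocab_size, hidden_size)`, and the helper methods are hypothetical.

```ruby
# Pure-Ruby shape check for the LM prediction head (no torch-rb needed).
# Assumption: these config values are illustrative BERT-base-style numbers.
Config = Struct.new(:hidden_size, :vocab_size, keyword_init: true)

# Hypothetical helper: a Linear(in, out) layer stores its weight as [out, in].
def linear_weight_shape(in_features, out_features)
  [out_features, in_features]
end

# Hypothetical helper: an Embedding(num, dim) table stores its weight as [num, dim].
def embedding_weight_shape(num_embeddings, embedding_dim)
  [num_embeddings, embedding_dim]
end

config = Config.new(hidden_size: 768, vocab_size: 30_522)

decoder_weight = linear_weight_shape(config.hidden_size, config.vocab_size)
word_embeddings = embedding_weight_shape(config.vocab_size, config.hidden_size)
bias_shape = [config.vocab_size] # one output-only bias entry per token

puts "decoder weight:  #{decoder_weight.inspect}"  # [30522, 768]
puts "word embeddings: #{word_embeddings.inspect}" # [30522, 768]
puts "same layout: #{decoder_weight == word_embeddings}"
puts "bias: #{bias_shape.inspect}"                 # [30522]
```

Because the two shapes match, one matrix can serve both as the input embedding table and as the decoder weight, while the bias stays a separate per-token vector owned by the head.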
Instance Method Details
#_tie_weights ⇒ Object
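Points the decoder's `@bias` at this head's `@bias` again, so both refer to the same parameter object (the same link the constructor sets up).

```ruby
# File 'lib/transformers/models/bert/modeling_bert.rb', line 492

def _tie_weights
  @decoder.instance_variable_set(:@bias, @bias)
end
```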
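The link set by `instance_variable_set` is plain Ruby object sharing: afterwards, the decoder's `@bias` and the head's `@bias` are the same object. The sketch below uses hypothetical `ToyLinear` and `ToyHead` classes, with arrays standing in for tensors, to show what `#_tie_weights` restores. Assumption: if `@bias` is ever replaced by a new object (modeled here by a hypothetical `resize_bias`), the decoder keeps pointing at the old one until the link is set again.

```ruby
# Hypothetical stand-in for Linear.new(..., bias: false): starts with no bias.
class ToyLinear
  attr_reader :bias

  def initialize
    @bias = nil
  end
end

# Hypothetical stand-in for the prediction head, using arrays instead of tensors.
class ToyHead
  attr_reader :decoder, :bias

  def initialize(vocab_size)
    @decoder = ToyLinear.new
    @bias = Array.new(vocab_size, 0.0)
    # Same trick as the real constructor: share one bias object.
    @decoder.instance_variable_set(:@bias, @bias)
  end

  # Hypothetical resize: builds a NEW array, zero-padding (or truncating) to new_size.
  def resize_bias(new_size)
    @bias = @bias.take(new_size) + Array.new([new_size - @bias.size, 0].max, 0.0)
  end

  # Mirrors the documented method: re-link the decoder to the current @bias.
  def _tie_weights
    @decoder.instance_variable_set(:@bias, @bias)
  end
end

head = ToyHead.new(4)
shared_before = head.decoder.bias.equal?(head.bias) # true: one shared object

head.resize_bias(6)
stale = !head.decoder.bias.equal?(head.bias)        # true: decoder still holds the old array
stale_size = head.decoder.bias.size                 # 4

head._tie_weights
relinked = head.decoder.bias.equal?(head.bias)      # true: link restored
```

Sharing the object (rather than copying values) is what keeps the two in sync; the re-link is only needed when the object itself is swapped out.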
#forward(hidden_states) ⇒ Object
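Applies `@transform` to `hidden_states`, then `@decoder`, which projects from `config.hidden_size` to `config.vocab_size` using the bias linked in the constructor, and returns the resulting scores over the vocabulary.

```ruby
# File 'lib/transformers/models/bert/modeling_bert.rb', line 496

def forward(hidden_states)
  hidden_states = @transform.(hidden_states)
  hidden_states = @decoder.(hidden_states)
  hidden_states
end
```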
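`#forward` is a two-step pipeline. A toy version on plain Ruby arrays makes the data flow concrete for a single token vector: the transform runs first (here a hypothetical identity lambda standing in for `BertPredictionHeadTransform`), then the decoder computes `weight * x + bias`, producing one score per vocabulary entry. This is only a sketch: the real method works on torch-rb tensors with batch and sequence dimensions, and the weights below are made-up numbers.

```ruby
# Toy decoder: x has hidden_size entries; weight is [vocab_size][hidden_size];
# returns vocab_size scores, each a dot product plus that token's bias.
def toy_decoder(x, weight, bias)
  weight.each_with_index.map do |row, i|
    row.zip(x).sum { |w, v| w * v } + bias[i]
  end
end

transform = ->(h) { h } # hypothetical identity stand-in for @transform

hidden_size = 2
vocab_size = 3
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] # [vocab_size, hidden_size]
bias = [0.5, 0.0, -1.0]                        # [vocab_size]

# Guard the shapes the real Linear layer would enforce.
unless weight.size == vocab_size && weight.all? { |r| r.size == hidden_size } && bias.size == vocab_size
  raise ArgumentError, "shape mismatch"
end

hidden_state = [2.0, 3.0] # one token's hidden vector
logits = toy_decoder(transform.(hidden_state), weight, bias)
# logits => [2.5, 3.0, 4.0]

# Index of the highest score, i.e. the predicted token id for this position.
predicted_token_id = logits.each_with_index.max_by { |score, _| score }.last
```

The `.()` call syntax on the lambda mirrors how the documented method invokes `@transform` and `@decoder`.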