Class: Transformers::DebertaV2::DebertaV2LMPredictionHead

Inherits: Torch::NN::Module (which itself inherits from Object)
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ DebertaV2LMPredictionHead

Returns a new instance of DebertaV2LMPredictionHead.



(source lines 891–904)
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 891

# Builds the language-model prediction head: a hidden-state transform followed
# by a linear decoder onto the vocabulary, plus a per-token output bias.
#
# @param config [Object] model config; reads `embedding_size` (optional),
#   `hidden_size` and `vocab_size`.
def initialize(config)
  super()
  @transform = DebertaV2PredictionHeadTransform.new(config)

  # Input width of the decoder. `getattr` is called with hidden_size as the
  # default — presumably used when the config has no embedding_size.
  @embedding_size = config.getattr("embedding_size", config.hidden_size)
  # The output weights are the same as the input embeddings, but there is
  # an output-only bias for each token, so the decoder is created without a
  # bias of its own and shares @bias instead (linked below).
  @decoder = Torch::NN::Linear.new(@embedding_size, config.vocab_size, bias: false)

  @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))

  # Need a link between the two variables so that the bias is correctly resized
  # with `resize_token_embeddings`: give @decoder the very same Parameter object
  # so that the decoder's forward pass adds it to the logits.
  # NOTE(review): relies on Torch::NN::Linear reading its bias from the @bias
  # instance variable in forward (as torch-rb's Linear does) — confirm on upgrade.
  @decoder.instance_variable_set(:@bias, @bias)
end

Instance Method Details

#_tie_weights ⇒ Object



(source lines 906–908)
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 906

# Re-links the decoder's bias to this head's standalone bias Parameter, so
# that @decoder and @bias refer to the same object again (e.g. after @bias
# has been replaced).
#
# NOTE(review): relies on Torch::NN::Linear reading its bias from the @bias
# instance variable in forward (as torch-rb's Linear does) — confirm on upgrade.
#
# @return [Object] the bias object that was linked
def _tie_weights
  @decoder.instance_variable_set(:@bias, @bias)
end

#forward(hidden_states) ⇒ Object



(source lines 910–914)
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 910

# Runs the head: applies @transform to the hidden states, then projects the
# result with @decoder. This method does not add @bias itself; an output bias
# is applied only if @decoder carries one.
#
# @param hidden_states [Object] input passed to @transform
# @return [Object] the output of @decoder
def forward(hidden_states)
  transformed = @transform.call(hidden_states)
  @decoder.call(transformed)
end