Class: Transformers::DebertaV2::DebertaV2LMPredictionHead
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::DebertaV2::DebertaV2LMPredictionHead
- Defined in: lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Instance Method Summary
- #_tie_weights ⇒ Object
- #forward(hidden_states) ⇒ Object
- #initialize(config) ⇒ DebertaV2LMPredictionHead (constructor)
  Returns a new instance of DebertaV2LMPredictionHead.
Constructor Details
#initialize(config) ⇒ DebertaV2LMPredictionHead
Returns a new instance of DebertaV2LMPredictionHead.
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 891

def initialize(config)
  super()
  @transform = DebertaV2PredictionHeadTransform.new(config)

  @embedding_size = config.getattr("embedding_size", config.hidden_size)
  # The output weights are the same as the input embeddings, but there is
  # an output-only bias for each token.
  @decoder = Torch::NN::Linear.new(@embedding_size, config.vocab_size, bias: false)

  @bias = Torch::NN::Parameter.new(Torch.zeros(config.vocab_size))

  # Need a link between the two variables so that the bias is correctly
  # resized with `resize_token_embeddings`
  @bias = @bias
end
Instance Method Details
#_tie_weights ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 906

def _tie_weights
  @bias = @bias
end
#forward(hidden_states) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 910

def forward(hidden_states)
  hidden_states = @transform.(hidden_states)
  hidden_states = @decoder.(hidden_states)
  hidden_states
end
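Conceptually, `forward` first passes the hidden states through the prediction-head transform, then projects them into vocabulary-sized logits via the bias-free `@decoder` linear layer. The plain-Ruby sketch below illustrates only that shape flow: the transform is replaced by an identity stand-in, the decoder weight is a toy constant matrix, and `EMBEDDING_SIZE`/`VOCAB_SIZE` are hypothetical sizes chosen for the example (the real head uses Torch tensors, a learned dense + activation + LayerNorm transform, and weights tied to the input embeddings).

```ruby
# Hypothetical sizes for illustration only; not from the library.
EMBEDDING_SIZE = 4
VOCAB_SIZE = 6

# Stand-in for DebertaV2PredictionHeadTransform: identity here.
# The real transform is dense + activation + LayerNorm.
def transform(hidden_states)
  hidden_states
end

# Stand-in for the decoder: a [VOCAB_SIZE, EMBEDDING_SIZE] weight matrix
# with no bias, mirroring Torch::NN::Linear.new(embedding_size, vocab_size, bias: false).
WEIGHT = Array.new(VOCAB_SIZE) { Array.new(EMBEDDING_SIZE) { 0.01 } }

def decode(hidden_states)
  hidden_states.map do |h|
    # One logit per vocabulary entry: dot product of h with each weight row.
    WEIGHT.map { |row| row.zip(h).sum { |w, x| w * x } }
  end
end

# Two token positions, each a hidden vector of size EMBEDDING_SIZE.
hidden_states = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]]
logits = decode(transform(hidden_states))

p logits.length        # one logit row per token position
p logits.first.length  # VOCAB_SIZE logits per position
```

The key point the sketch preserves: the projection maps `[seq_len, embedding_size]` hidden states to `[seq_len, vocab_size]` logits, which is why `@decoder` can share its weight matrix with the input embedding table.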