Class: Transformers::DebertaV2::DebertaV2ForMaskedLM
- Inherits: DebertaV2PreTrainedModel
- Ancestors: Object → Torch::NN::Module → PreTrainedModel → DebertaV2PreTrainedModel → Transformers::DebertaV2::DebertaV2ForMaskedLM
- Defined in: lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #get_output_embeddings ⇒ Object
- #initialize(config) ⇒ DebertaV2ForMaskedLM (constructor): A new instance of DebertaV2ForMaskedLM.
- #set_output_embeddings(new_embeddings) ⇒ Object
Methods inherited from DebertaV2PreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ DebertaV2ForMaskedLM
Returns a new instance of DebertaV2ForMaskedLM.
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 816

def initialize(config)
  super(config)

  @deberta = DebertaV2Model.new(config)
  @cls = DebertaV2OnlyMLMHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
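For orientation, a minimal construction sketch. The config class name and checkpoint id are illustrative assumptions; from_pretrained is the inherited loader listed under PreTrainedModel above.

# Build an untrained model from a config object.
# (Transformers::DebertaV2::DebertaV2Config is assumed to be the
# companion config class shipped alongside this model.)
config = Transformers::DebertaV2::DebertaV2Config.new
model = Transformers::DebertaV2::DebertaV2ForMaskedLM.new(config)

# Or load pretrained weights via the inherited class method
# ("microsoft/deberta-v3-base" is an illustrative checkpoint id).
model = Transformers::DebertaV2::DebertaV2ForMaskedLM.from_pretrained("microsoft/deberta-v3-base")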
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 835

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

  sequence_output = outputs[0]
  prediction_scores = @cls.(sequence_output)

  masked_lm_loss = nil
  if !labels.nil?
    loss_fct = Torch::NN::CrossEntropyLoss.new
    masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
  end

  if !return_dict
    output = [prediction_scores] + outputs[1..]
    return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
  end

  MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
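A sketch of calling #forward end to end. The tokenizer class, its callable syntax, and the checkpoint id are assumptions for illustration; only the forward signature itself comes from this page.

tokenizer = Transformers::AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = Transformers::DebertaV2::DebertaV2ForMaskedLM.from_pretrained("microsoft/deberta-v3-base")

# Encode a sentence with one masked position.
inputs = tokenizer.("Paris is the [MASK] of France.", return_tensors: "pt")

output = model.forward(
  input_ids: inputs[:input_ids],
  attention_mask: inputs[:attention_mask]
)

# With return_dict left at its default, a MaskedLMOutput is returned;
# logits has shape [batch_size, sequence_length, vocab_size].
logits = output.logits

When labels: is supplied, the method flattens the logits to [batch * sequence_length, vocab_size] as in the body above and returns the cross-entropy loss in the loss field; CrossEntropyLoss's default ignore_index of -100 lets unlabeled positions be skipped.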
#get_output_embeddings ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 826

def get_output_embeddings
  @cls.predictions.decoder
end
#set_output_embeddings(new_embeddings) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 830

def set_output_embeddings(new_embeddings)
  @decoder = new_embeddings
  @bias = new_embeddings.bias
end
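Taken together, the two accessors read and replace the vocabulary projection. A sketch, assuming a loaded model and that the inherited config reader exposes hidden_size and vocab_size:

# Read the current output projection (the MLM head's decoder layer).
decoder = model.get_output_embeddings

# Replace it with a fresh Torch::NN::Linear of matching shape;
# set_output_embeddings also picks up the new layer's bias, per the body above.
new_decoder = Torch::NN::Linear.new(model.config.hidden_size, model.config.vocab_size)
model.set_output_embeddings(new_decoder)

In practice these accessors are mostly used internally, for example by tie_weights and vocabulary resizing, rather than called directly.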