Class: Transformers::DebertaV2::DebertaV2ForMaskedLM

Inherits:
DebertaV2PreTrainedModel
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
#get_output_embeddings ⇒ Object
#initialize(config) ⇒ DebertaV2ForMaskedLM constructor
#set_output_embeddings(new_embeddings) ⇒ Object

Methods inherited from DebertaV2PreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ DebertaV2ForMaskedLM

Returns a new instance of DebertaV2ForMaskedLM.



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 816

def initialize(config)
  super(config)

  @deberta = DebertaV2Model.new(config)
  @cls = DebertaV2OnlyMLMHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
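
A hedged construction sketch: the model can also be built directly from a configuration object. The DebertaV2Config class name and the keyword arguments below mirror the Python API and are assumptions, not values taken from this file; for pretrained weights, the inherited from_pretrained class method is the usual entry point.

# Sketch: build an untrained model from a config (class name and arguments assumed from the Python API)
config = Transformers::DebertaV2::DebertaV2Config.new(vocab_size: 128100, hidden_size: 768)
model = Transformers::DebertaV2::DebertaV2ForMaskedLM.new(config)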

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 835

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @deberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

  sequence_output = outputs[0]
  prediction_scores = @cls.(sequence_output)

  masked_lm_loss = nil
  if !labels.nil?
    loss_fct = Torch::NN::CrossEntropyLoss.new
    masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
  end

  if !return_dict
    output = [prediction_scores] + outputs[1..]
    return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
  end

  MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
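
A minimal usage sketch for the forward pass; the checkpoint name, token ids, and mask id are illustrative stand-ins rather than values from this file.

# Sketch: score a toy batch with a pretrained checkpoint (names and ids are illustrative)
model = Transformers::DebertaV2::DebertaV2ForMaskedLM.from_pretrained("microsoft/deberta-v3-base")

input_ids = Torch.tensor([[1, 2054, 2003, 128000, 2]])  # assumed ids; 128000 stands in for the mask token
attention_mask = Torch.ones_like(input_ids)

output = model.forward(input_ids: input_ids, attention_mask: attention_mask, return_dict: true)
output.logits.shape  # => [1, 5, vocab_size]

# Passing labels (-100 marks positions CrossEntropyLoss ignores) also populates output.loss
labels = Torch.tensor([[-100, -100, -100, 2003, -100]])
with_loss = model.forward(input_ids: input_ids, attention_mask: attention_mask, labels: labels, return_dict: true)
with_loss.loss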

#get_output_embeddings ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 826

def get_output_embeddings
  @cls.predictions.decoder
end

#set_output_embeddings(new_embeddings) ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 830

def set_output_embeddings(new_embeddings)
  @decoder = new_embeddings
  @bias = new_embeddings.bias
end