Class: Transformers::DebertaV2::DebertaV2Model

Inherits:
DebertaV2PreTrainedModel
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

#_prune_heads(heads_to_prune) ⇒ Object
#forward(input_ids, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
#get_input_embeddings ⇒ Object
#initialize(config) ⇒ DebertaV2Model constructor
#set_input_embeddings(new_embeddings) ⇒ Object

Methods inherited from DebertaV2PreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ DebertaV2Model

Returns a new instance of DebertaV2Model.



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 728

def initialize(config)
  super(config)

  @embeddings = DebertaV2Embeddings.new(config)
  @encoder = DebertaV2Encoder.new(config)
  @z_steps = 0
  @config = config
  # Initialize weights and apply final processing
  post_init
end
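
A brief usage sketch (not part of the generated documentation): the usual entry point is the inherited PreTrainedModel.from_pretrained class method listed in the summary above; constructing from a config directly also works. The checkpoint name, the require path, and the DebertaV2Config keyword arguments are illustrative assumptions, not taken from this page.

# Sketch only: checkpoint name and config arguments are assumptions.
require "transformers-rb"

# Load pretrained weights via the inherited from_pretrained class method
model = Transformers::DebertaV2::DebertaV2Model.from_pretrained("microsoft/deberta-v3-base")

# Or build an untrained model from a config object; post_init then initializes the weights
config = Transformers::DebertaV2::DebertaV2Config.new(vocab_size: 128_100, hidden_size: 768)
model = Transformers::DebertaV2::DebertaV2Model.new(config)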

Instance Method Details

#_prune_heads(heads_to_prune) ⇒ Object

Raises:

  • (NotImplementedError)


# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 747

def _prune_heads(heads_to_prune)
  raise NotImplementedError, "The prune function is not implemented in DeBERTa model."
end
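
Head pruning is deliberately unsupported for DeBERTa, so a direct call raises, as does the inherited #prune_heads, which presumably delegates here. A minimal illustration, assuming a model instance as in the constructor example:

# Illustration: the error class and message are taken from the source above.
begin
  model._prune_heads({0 => [0, 1]})
rescue NotImplementedError => e
  puts e.message # => "The prune function is not implemented in DeBERTa model."
end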

#forward(input_ids, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 751

def forward(
  input_ids,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  inputs_embeds: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
  output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  if !input_ids.nil? && !inputs_embeds.nil?
    raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
  elsif !input_ids.nil?
    warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    input_shape = input_ids.size
  elsif !inputs_embeds.nil?
    input_shape = inputs_embeds.size[...-1]
  else
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end

  device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

  if attention_mask.nil?
    attention_mask = Torch.ones(input_shape, device: device)
  end
  if token_type_ids.nil?
    token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
  end

  embedding_output = @embeddings.(input_ids: input_ids, token_type_ids: token_type_ids, position_ids: position_ids, mask: attention_mask, inputs_embeds: inputs_embeds)

  encoder_outputs = @encoder.(embedding_output, attention_mask, output_hidden_states: true, output_attentions: output_attentions, return_dict: return_dict)
  encoded_layers = encoder_outputs[1]

  if @z_steps > 1
    hidden_states = encoded_layers[-2]
    layers = @z_steps.times.map { |_| @encoder.layer[-1] }
    query_states = encoded_layers[-1]
    rel_embeddings = @encoder.get_rel_embedding
    attention_mask = @encoder.get_attention_mask(attention_mask)
    rel_pos = @encoder.get_rel_pos(embedding_output)
    layers[1..].each do |layer|
      query_states = layer.(hidden_states, attention_mask, output_attentions: false, query_states: query_states, relative_pos: rel_pos, rel_embeddings: rel_embeddings)
      encoded_layers << query_states
    end
  end

  sequence_output = encoded_layers[-1]

  if !return_dict
    return [sequence_output] + encoder_outputs[(output_hidden_states ? 1 : 2)..]
  end

  BaseModelOutput.new(last_hidden_state: sequence_output, hidden_states: output_hidden_states ? encoder_outputs.hidden_states : nil, attentions: encoder_outputs.attentions)
end
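
An end-to-end calling sketch: Torch::NN::Module#call dispatches to #forward, so model.(...) is equivalent to model.forward(...). The checkpoint name, the AutoTokenizer usage, and hash-style access on the encoding are assumptions not documented on this page; only the #forward signature is.

# Sketch: only the #forward signature comes from this page; the rest is assumed.
checkpoint = "microsoft/deberta-v3-base"
tokenizer = Transformers::AutoTokenizer.from_pretrained(checkpoint)
model = Transformers::DebertaV2::DebertaV2Model.from_pretrained(checkpoint)

enc = tokenizer.("DeBERTa v2 uses disentangled attention.", return_tensors: "pt")

# input_ids is positional; everything else is a keyword argument
output = model.(enc[:input_ids], attention_mask: enc[:attention_mask])

output.last_hidden_state.shape # e.g. [1, sequence_length, hidden_size]
output.hidden_states           # nil unless output_hidden_states: true is passed

With return_dict: false, the method instead returns an array whose first element is the last hidden state, mirroring the non-return_dict branch at the end of the source above.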

#get_input_embeddings ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 739

def get_input_embeddings
  @embeddings.word_embeddings
end

#set_input_embeddings(new_embeddings) ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 743

def set_input_embeddings(new_embeddings)
  @embeddings.word_embeddings = new_embeddings
end
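
These two accessors expose the word-embedding layer, e.g. for swapping in a resized vocabulary. A short sketch; the dimensions are illustrative placeholders and must match the model's config (vocab_size x hidden_size).

# Sketch: embedding dimensions are placeholders, not values from this page.
current = model.get_input_embeddings          # => a Torch::NN::Embedding module
replacement = Torch::NN::Embedding.new(128_100, 768)
model.set_input_embeddings(replacement)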