Class: Transformers::XlmRoberta::XLMRobertaModel

Inherits:
XLMRobertaPreTrainedModel
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from XLMRobertaPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config, add_pooling_layer: true) ⇒ XLMRobertaModel

self._no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaLayer"]



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 609

def initialize(config, add_pooling_layer: true)
  super(config)
  @config = config

  @embeddings = XLMRobertaEmbeddings.new(config)
  @encoder = XLMRobertaEncoder.new(config)

  @pooler = add_pooling_layer ? XLMRobertaPooler.new(config) : nil

  @attn_implementation = config._attn_implementation
  @position_embedding_type = config.position_embedding_type

  # Initialize weights and apply final processing
  post_init
end
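
A minimal usage sketch, assuming the gem is required as transformers-rb and that an XLM-RoBERTa checkpoint name such as the one below is available; from_pretrained is inherited from PreTrainedModel (see the method summary above):

require "transformers-rb"

# Load configuration and pretrained weights (checkpoint name is illustrative).
model = Transformers::XlmRoberta::XLMRobertaModel.from_pretrained("xlm-roberta-base")

# Constructing directly from a config object also works; pass
# add_pooling_layer: false to skip the pooler.
# model = Transformers::XlmRoberta::XLMRobertaModel.new(config, add_pooling_layer: false)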

Instance Method Details

#_prune_heads(heads_to_prune) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 633

def _prune_heads(heads_to_prune)
  heads_to_prune.each do |layer, heads|
    @encoder.layer[layer].attention.prune_heads(heads)
  end
end
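
heads_to_prune is expected to be a hash mapping a layer index to the attention head indices to prune in that layer, mirroring the upstream Transformers convention. A hypothetical call (layer and head indices are illustrative):

# Prune heads 0 and 2 in layer 1, and head 4 in layer 5.
model._prune_heads({1 => [0, 2], 5 => [4]})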

#forward(input_ids, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_values: nil, use_cache: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 639

def forward(
  input_ids,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
  output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  if @config.is_decoder
    use_cache = !use_cache.nil? ? use_cache : @config.use_cache
  else
    use_cache = false
  end

  if !input_ids.nil? && !inputs_embeds.nil?
    raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
  elsif !input_ids.nil?
    warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    input_shape = input_ids.size
  elsif !inputs_embeds.nil?
    input_shape = inputs_embeds.size[...-1]
  else
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end

  batch_size, seq_length = input_shape
  device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

  # past_key_values_length
  past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0

  if token_type_ids.nil?
    if @embeddings.respond_to?(:token_type_ids)
      buffered_token_type_ids = @embeddings.token_type_ids[0.., ...seq_length]
      buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
      token_type_ids = buffered_token_type_ids_expanded
    else
      token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
    end
  end

  embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, token_type_ids: token_type_ids, inputs_embeds: inputs_embeds, past_key_values_length: past_key_values_length)

  if attention_mask.nil?
    attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
  end

  use_sdpa_attention_masks = @attn_implementation == "sdpa" && @position_embedding_type == "absolute" && head_mask.nil? && !output_attentions

  # Expand the attention mask
  if use_sdpa_attention_masks && attention_mask.dim == 2
    # Expand the attention mask for SDPA.
    # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
    if @config.is_decoder
      extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, embedding_output, past_key_values_length)
    else
      extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(attention_mask, embedding_output.dtype, tgt_len: seq_length)
    end
  else
    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
  end

  # If a 2D or 3D attention mask is provided for the cross-attention
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  if @config.is_decoder && !encoder_hidden_states.nil?
    encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
    encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
    if encoder_attention_mask.nil?
      encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
    end

    if use_sdpa_attention_masks && encoder_attention_mask.dim == 2
      # Expand the attention mask for SDPA.
      # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
      encoder_extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length)
    else
      encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
    end
  else
    encoder_extended_attention_mask = nil
  end

  # Prepare head mask if needed
  # 1.0 in head_mask indicate we keep the head
  # attention_probs has shape bsz x n_heads x N x N
  # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

  encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_extended_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
  sequence_output = encoder_outputs[0]
  pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

  if !return_dict
    return [sequence_output, pooled_output] + encoder_outputs[1..]
  end

  BaseModelOutputWithPoolingAndCrossAttentions.new(last_hidden_state: sequence_output, pooler_output: pooled_output, past_key_values: encoder_outputs.past_key_values, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions, cross_attentions: encoder_outputs.cross_attentions)
end
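
A minimal sketch of calling forward, assuming model was loaded as in the constructor example and that the token ids below stand in for output from a matching tokenizer. Options left nil fall back to the corresponding config values; pooler_output is nil when the model was built with add_pooling_layer: false:

# Two padded sequences of length 3 (placeholder token ids).
input_ids = Torch.tensor([[0, 581, 2], [0, 9, 2]])
attention_mask = Torch.tensor([[1, 1, 1], [1, 1, 0]])

output = model.forward(input_ids, attention_mask: attention_mask)
output.last_hidden_state.shape # => [2, 3, hidden_size]
output.pooler_output.shape     # => [2, hidden_size]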

#get_input_embeddings ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 625

def get_input_embeddings
  @embeddings.word_embeddings
end

#set_input_embeddings(value) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 629

def set_input_embeddings(value)
  @embeddings.word_embeddings = value
end