Class: Transformers::XlmRoberta::XLMRobertaEncoder

Inherits:
  Torch::NN::Module
    • Object
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ XLMRobertaEncoder

Returns a new instance of XLMRobertaEncoder.



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 494

def initialize(config)
  super()
  @config = config
  @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { |_| XLMRobertaLayer.new(config) })
  @gradient_checkpointing = false
end
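
The constructor just stores the config, builds a Torch::NN::ModuleList holding config.num_hidden_layers identical XLMRobertaLayer blocks, and leaves gradient checkpointing disabled; the encoder is normally created for you inside XLMRobertaModel. A minimal sketch of building one directly is shown below (the XLMRobertaConfig keyword arguments are illustrative assumptions, not documented defaults):

# Sketch only, with the transformers-rb gem loaded. Assumes
# Transformers::XlmRoberta::XLMRobertaConfig accepts these keyword
# arguments and supplies sensible defaults for the rest.
config = Transformers::XlmRoberta::XLMRobertaConfig.new(
  hidden_size: 64,
  num_hidden_layers: 2,
  num_attention_heads: 4,
  intermediate_size: 128
)
encoder = Transformers::XlmRoberta::XLMRobertaEncoder.new(config)
encoder.instance_variable_get(:@layer).length # => 2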

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_values: nil, use_cache: nil, output_attentions: false, output_hidden_states: false, return_dict: true) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 501

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: false,
  output_hidden_states: false,
  return_dict: true
)
  all_hidden_states = output_hidden_states ? [] : nil
  all_self_attentions = output_attentions ? [] : nil
  all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil

  if @gradient_checkpointing && @training
    if use_cache
      Transformers.logger.warn("`use_cache: true` is incompatible with gradient checkpointing. Setting `use_cache: false`...")
      use_cache = false
    end
  end

  next_decoder_cache = use_cache ? [] : nil
  @layer.each_with_index do |layer_module, i|
    if output_hidden_states
      all_hidden_states = all_hidden_states + [hidden_states]
    end

    layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
    past_key_value = !past_key_values.nil? ? past_key_values[i] : nil

    if @gradient_checkpointing && @training
      layer_outputs = _gradient_checkpointing_func(layer_module.__call__, hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
    else
      layer_outputs = layer_module.(hidden_states, attention_mask:, head_mask: layer_head_mask, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
    end

    hidden_states = layer_outputs[0]
    if use_cache
      next_decoder_cache += [layer_outputs[-1]]
    end
    if output_attentions
      all_self_attentions = all_self_attentions + [layer_outputs[1]]
      if @config.add_cross_attention
        all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
      end
    end
  end

  if output_hidden_states
    all_hidden_states = all_hidden_states + [hidden_states]
  end

  if !return_dict
    return Array([hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions].select { |v| !v.nil? })
  end
  BaseModelOutputWithPastAndCrossAttentions.new(last_hidden_state: hidden_states, past_key_values: next_decoder_cache, hidden_states: all_hidden_states, attentions: all_self_attentions, cross_attentions: all_cross_attentions)
end
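
For an encoder-only pass, only hidden_states and (optionally) an additive extended attention mask are supplied; the cross-attention and cache arguments stay nil, and with return_dict: true the result is a BaseModelOutputWithPastAndCrossAttentions. A hedged sketch continuing the example above (the call syntax via Torch::NN::Module#call, the mask format, and the output readers are assumptions based on this code, not documented behaviour):

# Sketch only. Assumes the mask is additive and broadcastable to
# [batch, heads, seq_len, seq_len], with 0.0 marking visible positions.
batch, seq_len = 1, 8
hidden_states = Torch.randn(batch, seq_len, config.hidden_size)
extended_mask = Torch.zeros(batch, 1, 1, seq_len) # nothing masked

out = encoder.(hidden_states, attention_mask: extended_mask, output_hidden_states: true)
out.last_hidden_state.shape # => [1, 8, 64]
out.hidden_states.length    # => 3 (the input plus one entry per layer)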