Class: Transformers::XlmRoberta::XLMRobertaLayer

Inherits:
  Torch::NN::Module
    • Object
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ XLMRobertaLayer

Returns a new instance of XLMRobertaLayer.



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 419

def initialize(config)
  super()
  @chunk_size_feed_forward = config.chunk_size_feed_forward
  @seq_len_dim = 1
  @attention = XLMRobertaAttention.new(config)
  @is_decoder = config.is_decoder
  @add_cross_attention = config.add_cross_attention
  if @add_cross_attention
    if !@is_decoder
      raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
    end
    @crossattention = XLMRobertaAttention.new(config, position_embedding_type: "absolute")
  end
  @intermediate = XLMRobertaIntermediate.new(config)
  @output = XLMRobertaOutput.new(config)
end
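
For reference, a minimal construction sketch. It assumes the companion Transformers::XlmRoberta::XLMRobertaConfig class defined alongside this layer and its default attribute values (both assumptions, not shown on this page); with add_cross_attention left false, only the self-attention, intermediate and output sub-modules are built:

require "transformers-rb"

# Assumed default encoder-style config: is_decoder and add_cross_attention are false,
# so the constructor skips the @crossattention branch entirely.
config = Transformers::XlmRoberta::XLMRobertaConfig.new
layer = Transformers::XlmRoberta::XLMRobertaLayer.new(config)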

Instance Method Details

#feed_forward_chunk(attention_output) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 486

def feed_forward_chunk(attention_output)
  intermediate_output = @intermediate.(attention_output)
  layer_output = @output.(intermediate_output, attention_output)
  layer_output
end
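
This method is the per-chunk callback that #forward hands to TorchUtils.apply_chunking_to_forward. As a conceptual sketch only (assuming torch.rb's Tensor#split and Torch.cat; the real helper is TorchUtils.apply_chunking_to_forward), chunking splits the sequence dimension, applies the callback chunk by chunk to bound peak memory, and concatenates the results:

# Illustrative sketch of chunked feed-forward, not the library's implementation.
def chunked_feed_forward(attention_output, chunk_size, seq_len_dim)
  # A chunk size of 0 means "no chunking": run the whole tensor at once.
  return feed_forward_chunk(attention_output) if chunk_size == 0

  chunks = attention_output.split(chunk_size, dim: seq_len_dim)
  outputs = chunks.map { |chunk| feed_forward_chunk(chunk) }
  Torch.cat(outputs, dim: seq_len_dim)
end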

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 436

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
  self_attention_outputs = @attention.(hidden_states, attention_mask:, head_mask:, output_attentions: output_attentions, past_key_value: self_attn_past_key_value)
  attention_output = self_attention_outputs[0]

  # if decoder, the last output is tuple of self-attn cache
  if @is_decoder
    outputs = self_attention_outputs[1...-1]
    present_key_value = self_attention_outputs[-1]
  else
    outputs = self_attention_outputs[1..]
  end

  cross_attn_present_key_value = nil
  if @is_decoder && !encoder_hidden_states.nil?
    if !instance_variable_defined?(:@crossattention)
      raise ArgumentError, "If `encoder_hidden_states` are passed, #{self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
    end

    # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
    cross_attn_past_key_value = !past_key_value.nil? ? past_key_value[-2..] : nil
    cross_attention_outputs = @crossattention.(attention_output, attention_mask: attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_value: cross_attn_past_key_value, output_attentions: output_attentions)
    attention_output = cross_attention_outputs[0]
    outputs = outputs + cross_attention_outputs[1...-1]

    # add cross-attn cache to positions 3,4 of present_key_value tuple
    cross_attn_present_key_value = cross_attention_outputs[-1]
    present_key_value = present_key_value + cross_attn_present_key_value
  end

  layer_output = TorchUtils.apply_chunking_to_forward(method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output)
  outputs = [layer_output] + outputs

  # if decoder, return the attn key/values as the last output
  if @is_decoder
    outputs = outputs + [present_key_value]
  end

  outputs
end
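
A minimal forward-pass sketch, continuing the assumed config and layer from the constructor example above (shapes and variable names are illustrative):

# Batch of 2 sequences, 8 tokens each; hidden size comes from the config.
hidden_states = Torch.randn(2, 8, config.hidden_size)

outputs = layer.(hidden_states)   # Torch::NN::Module call syntax dispatches to #forward
layer_output = outputs[0]         # tensor of shape [2, 8, config.hidden_size]
# With output_attentions: true, outputs[1..] also carries the attention probabilities;
# in a decoder configuration the final element is the present key/value cache.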