Class: Transformers::XlmRoberta::XLMRobertaLayer
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::XlmRoberta::XLMRobertaLayer
- Defined in: lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Method Summary
- #feed_forward_chunk(attention_output) ⇒ Object
- #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
- #initialize(config) ⇒ XLMRobertaLayer (constructor): A new instance of XLMRobertaLayer.
Constructor Details
#initialize(config) ⇒ XLMRobertaLayer
Returns a new instance of XLMRobertaLayer.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 419

def initialize(config)
  super()
  @chunk_size_feed_forward = config.chunk_size_feed_forward
  @seq_len_dim = 1
  @attention = XLMRobertaAttention.new(config)
  @is_decoder = config.is_decoder
  @add_cross_attention = config.add_cross_attention
  if @add_cross_attention
    if !@is_decoder
      raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
    end
    @crossattention = XLMRobertaAttention.new(config, position_embedding_type: "absolute")
  end
  @intermediate = XLMRobertaIntermediate.new(config)
  @output = XLMRobertaOutput.new(config)
end
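As a rough usage sketch (not part of the generated docs): the constructor only needs a config object exposing chunk_size_feed_forward, is_decoder, and add_cross_attention, plus the attributes read by the attention, intermediate, and output sub-modules. The default XLMRobertaConfig is assumed here to describe an encoder-style layer; that assumption and the gem setup are not stated on this page.

  # Minimal sketch, assuming the transformers-rb and torch-rb gems are loaded and
  # that the default Transformers::XlmRoberta::XLMRobertaConfig is encoder-style
  # (is_decoder: false, add_cross_attention: false).
  config = Transformers::XlmRoberta::XLMRobertaConfig.new
  layer  = Transformers::XlmRoberta::XLMRobertaLayer.new(config)
  # The layer now owns @attention, @intermediate, and @output sub-modules;
  # @crossattention is only created when config.add_cross_attention is true.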
Instance Method Details
#feed_forward_chunk(attention_output) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 486

def feed_forward_chunk(attention_output)
  intermediate_output = @intermediate.(attention_output)
  layer_output = @output.(intermediate_output, attention_output)
  layer_output
end
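For context, #feed_forward_chunk is the per-chunk body used by #forward: apply_chunking_to_forward slices the attention output along @seq_len_dim (dimension 1) into chunks of config.chunk_size_feed_forward, runs this method on each chunk, and concatenates the results, which should match the unchunked result up to floating-point error. A hedged sketch, reusing the layer from the constructor example; the module path Transformers::TorchUtils, the chunk size, and the tensor shapes are assumptions:

  # Sketch only: batch 2, seq_len 8, hidden size 768 (the default) and chunk size 4 are assumed.
  attention_output = Torch.randn(2, 8, 768)
  unchunked = layer.feed_forward_chunk(attention_output)
  chunked = Transformers::TorchUtils.apply_chunking_to_forward(
    layer.method(:feed_forward_chunk), 4, 1, attention_output
  )
  # chunked and unchunked should agree up to numerical precision; chunking trades
  # peak memory in the intermediate projection for extra forward calls.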
#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 436

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
  self_attention_outputs = @attention.(hidden_states, attention_mask:, head_mask:, output_attentions: output_attentions, past_key_value: self_attn_past_key_value)
  attention_output = self_attention_outputs[0]

  # if decoder, the last output is a tuple of the self-attn cache
  if @is_decoder
    outputs = self_attention_outputs[1...-1]
    present_key_value = self_attention_outputs[-1]
  else
    outputs = self_attention_outputs[1..]
  end

  cross_attn_present_key_value = nil
  if @is_decoder && !encoder_hidden_states.nil?
    if !instance_variable_defined?(:@crossattention)
      raise ArgumentError, "If `encoder_hidden_states` are passed, #{self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
    end

    # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
    cross_attn_past_key_value = !past_key_value.nil? ? past_key_value[-2..] : nil
    cross_attention_outputs = @crossattention.(attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, cross_attn_past_key_value, output_attentions)
    attention_output = cross_attention_outputs[0]
    outputs = outputs + cross_attention_outputs[1...-1]

    # add cross-attn cache to positions 3,4 of present_key_value tuple
    cross_attn_present_key_value = cross_attention_outputs[-1]
    present_key_value = present_key_value + cross_attn_present_key_value
  end

  layer_output = TorchUtils.apply_chunking_to_forward(method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output)
  outputs = [layer_output] + outputs

  # if decoder, return the attn key/values as the last output
  if @is_decoder
    outputs = outputs + [present_key_value]
  end

  outputs
end
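A hedged encoder-style call sketch (no cache, no cross-attention), continuing the examples above; the tensor shapes and the additive attention-mask convention (0.0 for visible positions, a large negative value for masked ones, broadcastable over heads) are assumptions rather than statements from this page:

  hidden_states  = Torch.randn(2, 8, 768)   # [batch, seq_len, hidden_size]
  attention_mask = Torch.zeros(2, 1, 1, 8)  # additive mask; all positions visible here
  outputs = layer.forward(hidden_states, attention_mask: attention_mask)
  layer_output = outputs[0]                 # [2, 8, 768], same shape as the input
  # With output_attentions: true, outputs would also include the self-attention
  # probabilities; in decoder mode the present key/value cache is appended last.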