# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 639
def forward(
  input_ids,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  # Fall back to the model config for any output options not passed explicitly
  output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
  output_hidden_states = !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  # Key/value caching is only meaningful when the model is used as a decoder
  if @config.is_decoder
    use_cache = !use_cache.nil? ? use_cache : @config.use_cache
  else
    use_cache = false
  end

  # Exactly one of input_ids and inputs_embeds must be provided
  if !input_ids.nil? && !inputs_embeds.nil?
    raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
  elsif !input_ids.nil?
    warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    input_shape = input_ids.size
  elsif !inputs_embeds.nil?
    input_shape = inputs_embeds.size[...-1]
  else
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end

  batch_size, seq_length = input_shape
  device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

  past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0

  # Default token_type_ids to the buffered ids when available, otherwise zeros
  if token_type_ids.nil?
    if @embeddings.respond_to?(:token_type_ids)
      buffered_token_type_ids = @embeddings.token_type_ids[0.., ...seq_length]
      buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
      token_type_ids = buffered_token_type_ids_expanded
    else
      token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
    end
  end

  embedding_output = @embeddings.(input_ids: input_ids, position_ids: position_ids, token_type_ids: token_type_ids, inputs_embeds: inputs_embeds, past_key_values_length: past_key_values_length)

  if attention_mask.nil?
    attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
  end

  # SDPA can expand a 2D mask to 4D itself; otherwise build the extended
  # (broadcastable) attention mask manually
  use_sdpa_attention_masks = @attn_implementation == "sdpa" && @position_embedding_type == "absolute" && head_mask.nil? && !output_attentions

  if use_sdpa_attention_masks && attention_mask.dim == 2
    if @config.is_decoder
      extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, embedding_output, past_key_values_length)
    else
      extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(attention_mask, embedding_output.dtype, tgt_len: seq_length)
    end
  else
    extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
  end

  # When acting as a decoder with cross-attention, build a mask over the encoder states
  if @config.is_decoder && !encoder_hidden_states.nil?
    encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
    encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
    if encoder_attention_mask.nil?
      encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
    end
    if use_sdpa_attention_masks && encoder_attention_mask.dim == 2
      encoder_extended_attention_mask = ModelingAttnMaskUtils._prepare_4d_attention_mask_for_sdpa(encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length)
    else
      encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
    end
  else
    encoder_extended_attention_mask = nil
  end

  head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

  encoder_outputs = @encoder.(embedding_output, attention_mask: extended_attention_mask, head_mask: head_mask, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_extended_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)
  sequence_output = encoder_outputs[0]
  pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

  # Return a plain array when return_dict is false, otherwise a structured output
  if !return_dict
    return [sequence_output, pooled_output] + encoder_outputs[1..]
  end

  BaseModelOutputWithPoolingAndCrossAttentions.new(last_hidden_state: sequence_output, pooler_output: pooled_output, past_key_values: encoder_outputs.past_key_values, hidden_states: encoder_outputs.hidden_states, attentions: encoder_outputs.attentions, cross_attentions: encoder_outputs.cross_attentions)
end
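For orientation, below is a minimal invocation sketch, not part of the library source. It assumes the transformers-rb and torch-rb gems are loaded, that the class containing this #forward is exposed as Transformers::XlmRoberta::XLMRobertaModel with a from_pretrained constructor mirroring the Python API, and that the model is callable via .() as the source above does with @embeddings.(...) and @encoder.(...). The token ids are placeholders, not real tokenizer output; verify the constant and method names against the gem before relying on them.

# Hypothetical usage sketch -- class/constructor names are assumptions.
model = Transformers::XlmRoberta::XLMRobertaModel.from_pretrained("xlm-roberta-base")

input_ids = Torch.tensor([[0, 35_378, 8_999, 2]])   # placeholder token ids, shape [1, 4]
attention_mask = Torch.ones_like(input_ids)          # no padding, so all positions attend

output = model.(input_ids, attention_mask: attention_mask)

# With return_dict left at its default, the structured output exposes the fields
# assigned in the constructor at the end of #forward.
sequence_output = output.last_hidden_state   # [batch_size, seq_length, hidden_size]
pooled_output = output.pooler_output         # nil when the model was built without a pooler

With return_dict: false, the same call would instead return an array whose first two elements are the sequence output and pooled output, followed by the remaining encoder outputs, matching the early return in the source.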