Class: Transformers::Bert::BertModel

Inherits:
BertPreTrainedModel
Defined in:
lib/transformers/models/bert/modeling_bert.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from BertPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config, add_pooling_layer: true) ⇒ BertModel

Returns a new instance of BertModel.



(source lines 538–552)
# File 'lib/transformers/models/bert/modeling_bert.rb', line 538

# Builds a BERT model from +config+.
#
# @param config [Object] model configuration; it is handed to the superclass
#   and to every submodule constructor, and its +_attn_implementation+ and
#   +position_embedding_type+ values are cached for use in #forward.
# @param add_pooling_layer [Boolean] when truthy (the default) a BertPooler is
#   built; otherwise +@pooler+ is set to nil.
def initialize(config, add_pooling_layer: true)
  # Pass only the config upward; a bare `super` would also forward the
  # add_pooling_layer keyword.
  super(config)
  @config = config

  # Submodules, created in a fixed order: embeddings, encoder, optional pooler.
  @embeddings = BertEmbeddings.new(config)
  @encoder = BertEncoder.new(config)
  @pooler = (BertPooler.new(config) if add_pooling_layer)

  # Cached so #forward can decide which attention-mask path applies.
  @attn_implementation = config._attn_implementation
  @position_embedding_type = config.position_embedding_type

  # Initialize weights and apply final processing
  post_init
end

Instance Method Details

#_prune_heads(heads_to_prune) ⇒ Object



(source lines 554–558)
# File 'lib/transformers/models/bert/modeling_bert.rb', line 554

# Asks selected encoder layers to prune attention heads.
#
# @param heads_to_prune [Hash] pairs of (layer index, heads). Each key indexes
#   into the encoder's +layer+ list, and each value is passed unchanged to that
#   layer's +attention.prune_heads+. Presumably {Integer => Array<Integer>};
#   confirm against callers.
# @return [Object] +heads_to_prune+ itself (the return value of #each).
def _prune_heads(heads_to_prune)
  heads_to_prune.each do |layer_index, head_indices|
    layer_attention = @encoder.layer[layer_index].attention
    layer_attention.prune_heads(head_indices)
  end
end

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_values: nil, use_cache: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



(source lines 560–698)
# File 'lib/transformers/models/bert/modeling_bert.rb', line 560

# Runs a forward pass. It embeds the inputs, builds the self-attention mask
# (and, in decoder mode, the cross-attention mask), runs the encoder, and pools
# the sequence output when a pooler exists.
#
# Exactly one of +input_ids+ / +inputs_embeds+ must be supplied. A nil value
# for +output_attentions+, +output_hidden_states+ or +return_dict+ falls back
# to the matching config value (+use_return_dict+ for return_dict).
#
# @param input_ids [Torch::Tensor, nil] token ids. The first two entries of
#   its +size+ are used as batch_size and seq_length.
# @param attention_mask [Torch::Tensor, nil] defaults to ones of shape
#   [batch_size, seq_length + past_key_values_length].
# @param token_type_ids [Torch::Tensor, nil] when nil, it is taken from the
#   embeddings' +token_type_ids+ (first seq_length columns, expanded to
#   batch_size rows). If that value is falsy, zeros of the input shape are used.
# @param position_ids [Torch::Tensor, nil] passed straight to the embeddings.
# @param head_mask [Object, nil] normalized via #get_head_mask for
#   config.num_hidden_layers layers before it reaches the encoder.
# @param inputs_embeds [Torch::Tensor, nil] precomputed embeddings. Every
#   dimension except the last forms the input shape.
# @param encoder_hidden_states [Torch::Tensor, nil] forwarded to the encoder.
#   A cross-attention mask is only built for it when config.is_decoder.
# @param encoder_attention_mask [Torch::Tensor, nil] in decoder mode, defaults
#   to ones of shape [encoder_batch_size, encoder_sequence_length].
# @param past_key_values [Object, nil] cached keys/values. The past length is
#   read from past_key_values[0][0].shape[2].
# @param use_cache [Boolean, nil] only honored when config.is_decoder;
#   forced to false otherwise.
# @param output_attentions [Boolean, nil]
# @param output_hidden_states [Boolean, nil]
# @param return_dict [Boolean, nil]
# @return [BaseModelOutputWithPoolingAndCrossAttentions]
# @raise [ArgumentError] if both or neither of input_ids / inputs_embeds are given.
# @raise [Todo] when the SDPA mask path is selected, or when return_dict is
#   false (neither path is ported yet).
def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  # nil means "use the config default"; an explicit false is respected.
  output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
  output_hidden_states = (
    !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
  )
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  # The cache flag is only honored in decoder mode.
  if @config.is_decoder
    use_cache = !use_cache.nil? ? use_cache : @config.use_cache
  else
    use_cache = false
  end

  # Require exactly one input source and derive the input shape from it.
  if !input_ids.nil? && !inputs_embeds.nil?
    raise ArgumentError, "You cannot specify both input_ids and inputs_embeds at the same time"
  elsif !input_ids.nil?
    # self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    input_shape = input_ids.size
  elsif !inputs_embeds.nil?
    # Drop the last dimension of the embeddings.
    input_shape = inputs_embeds.size[...-1]
  else
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end

  batch_size, seq_length = input_shape
  # New tensors below are created on the same device as the provided input.
  device = !input_ids.nil? ? input_ids.device : inputs_embeds.device

  # past_key_values_length: dim 2 of the first cached tensor (presumably the
  # cached sequence length — confirm); 0 when there is no cache.
  past_key_values_length = !past_key_values.nil? ? past_key_values[0][0].shape[2] : 0

  if token_type_ids.nil?
    if @embeddings.token_type_ids
      # Take the first seq_length columns and broadcast them over the batch.
      buffered_token_type_ids = @embeddings.token_type_ids[0.., 0...seq_length]
      buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
      token_type_ids = buffered_token_type_ids_expanded
    else
      token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: device)
    end
  end

  embedding_output = @embeddings.(
    input_ids: input_ids,
    position_ids: position_ids,
    token_type_ids: token_type_ids,
    inputs_embeds: inputs_embeds,
    past_key_values_length: past_key_values_length
  )

  # Default mask: attend to every current and cached position.
  if attention_mask.nil?
    attention_mask = Torch.ones([batch_size, seq_length + past_key_values_length], device: device)
  end

  # Whether the SDPA-specific attention-mask path applies.
  use_sdpa_attention_masks = (
    @attn_implementation == "sdpa" &&
    @position_embedding_type == "absolute" &&
    head_mask.nil? &&
    !output_attentions
  )

  # Expand the attention mask
  if use_sdpa_attention_masks
    # NOTE(review): the SDPA mask path is not ported yet, so any call that
    # meets the conditions above raises here.
    raise Todo
  else
    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask = get_extended_attention_mask(attention_mask, input_shape)
  end

  # If a 2D or 3D attention mask is provided for the cross-attention
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  if @config.is_decoder && !encoder_hidden_states.nil?
    encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size
    encoder_hidden_shape = [encoder_batch_size, encoder_sequence_length]
    if encoder_attention_mask.nil?
      encoder_attention_mask = Torch.ones(encoder_hidden_shape, device: device)
    end

    # NOTE(review): the true branch is unreachable as written. When
    # use_sdpa_attention_masks is true, the `raise Todo` above has already fired.
    if use_sdpa_attention_masks
      # Expand the attention mask for SDPA.
      # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
      encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
        encoder_attention_mask, embedding_output.dtype, tgt_len: seq_length
      )
    else
      # NOTE(review): invert_attention_mask is not among the ModuleUtilsMixin
      # methods listed for this class; confirm it is defined somewhere.
      encoder_extended_attention_mask = invert_attention_mask(encoder_attention_mask)
    end
  else
    encoder_extended_attention_mask = nil
  end

  # Prepare head mask if needed
  # 1.0 in head_mask indicate we keep the head
  # attention_probs has shape bsz x n_heads x N x N
  # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

  encoder_outputs = @encoder.(
    embedding_output,
    attention_mask: extended_attention_mask,
    head_mask: head_mask,
    encoder_hidden_states: encoder_hidden_states,
    encoder_attention_mask: encoder_extended_attention_mask,
    past_key_values: past_key_values,
    use_cache: use_cache,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  sequence_output = encoder_outputs[0]
  # Pool only when the model was constructed with a pooler.
  pooled_output = !@pooler.nil? ? @pooler.(sequence_output) : nil

  # Tuple-style (return_dict: false) outputs are not ported yet.
  if !return_dict
    raise Todo
  end

  BaseModelOutputWithPoolingAndCrossAttentions.new(
    last_hidden_state: sequence_output,
    pooler_output: pooled_output,
    past_key_values: encoder_outputs.past_key_values,
    hidden_states: encoder_outputs.hidden_states,
    attentions: encoder_outputs.attentions,
    cross_attentions: encoder_outputs.cross_attentions
  )
end