Class: Transformers::DebertaV2::DebertaV2Embeddings

Inherits:
Torch::NN::Module
  • Object
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Method Summary

#initialize(config) ⇒ DebertaV2Embeddings constructor
#forward(input_ids: nil, token_type_ids: nil, position_ids: nil, mask: nil, inputs_embeds: nil) ⇒ Object

Constructor Details

#initialize(config) ⇒ DebertaV2Embeddings

Returns a new instance of DebertaV2Embeddings.



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 617

def initialize(config)
  super()
  pad_token_id = config.getattr("pad_token_id", 0)
  @embedding_size = config.getattr("embedding_size", config.hidden_size)
  @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, @embedding_size, padding_idx: pad_token_id)

  @position_biased_input = config.getattr("position_biased_input", true)
  if !@position_biased_input
    @position_embeddings = nil
  else
    @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, @embedding_size)
  end

  if config.type_vocab_size > 0
    @token_type_embeddings = Torch::NN::Embedding.new(config.type_vocab_size, @embedding_size)
  end

  if @embedding_size != config.hidden_size
    @embed_proj = Torch::NN::Linear.new(@embedding_size, config.hidden_size, bias: false)
  end
  @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  @dropout = StableDropout.new(config.hidden_dropout_prob)
  @config = config

  # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  register_buffer("position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false)
end
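
A minimal usage sketch (not part of the library source): the snippet below builds a DebertaV2Embeddings instance from a hand-rolled config object. The MiniConfig class is hypothetical and only mirrors the config.getattr interface plus the accessors the constructor reads; in practice the config comes from a pretrained DeBERTa-v2 checkpoint.

require "transformers"

# Hypothetical stand-in for a DeBERTa-v2 config; real configs are loaded from
# a pretrained checkpoint. It only needs the accessors used by the constructor
# and a getattr(name, default) lookup.
MiniConfig = Struct.new(
  :vocab_size, :hidden_size, :max_position_embeddings, :type_vocab_size,
  :layer_norm_eps, :hidden_dropout_prob, keyword_init: true
) do
  def getattr(name, default)
    respond_to?(name) && !send(name).nil? ? send(name) : default
  end
end

config = MiniConfig.new(
  vocab_size: 128_100,
  hidden_size: 768,
  max_position_embeddings: 512,
  type_vocab_size: 0,
  layer_norm_eps: 1e-7,
  hidden_dropout_prob: 0.1
)

embeddings = Transformers::DebertaV2::DebertaV2Embeddings.new(config)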

Instance Method Details

#forward(input_ids: nil, token_type_ids: nil, position_ids: nil, mask: nil, inputs_embeds: nil) ⇒ Object



# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 645

def forward(input_ids: nil, token_type_ids: nil, position_ids: nil, mask: nil, inputs_embeds: nil)
  if !input_ids.nil?
    input_shape = input_ids.size
  else
    input_shape = inputs_embeds.size[...-1]
  end

  seq_length = input_shape[1]

  if position_ids.nil?
    position_ids = @position_ids[0.., ...seq_length]
  end

  if token_type_ids.nil?
    token_type_ids = Torch.zeros(input_shape, dtype: Torch.long, device: @position_ids.device)
  end

  if inputs_embeds.nil?
    inputs_embeds = @word_embeddings.(input_ids)
  end

  if !@position_embeddings.nil?
    position_embeddings = @position_embeddings.(position_ids.long)
  else
    position_embeddings = Torch.zeros_like(inputs_embeds)
  end

  embeddings = inputs_embeds
  if @position_biased_input
    embeddings += position_embeddings
  end
  if @config.type_vocab_size > 0
    token_type_embeddings = @token_type_embeddings.(token_type_ids)
    embeddings += token_type_embeddings
  end

  if @embedding_size != @config.hidden_size
    embeddings = @embed_proj.(embeddings)
  end

  embeddings = @LayerNorm.(embeddings)

  # If an attention mask is given, collapse it to [batch, seq, 1] and zero
  # out padded positions before dropout.
  if !mask.nil?
    if mask.dim != embeddings.dim
      if mask.dim == 4
        mask = mask.squeeze(1).squeeze(1)
      end
      mask = mask.unsqueeze(2)
    end
    mask = mask.to(embeddings.dtype)

    embeddings = embeddings * mask
  end

  embeddings = @dropout.(embeddings)
  embeddings
end
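
A sketch of a forward pass over the `embeddings` instance built in the constructor example above; the token ids and shapes are arbitrary illustrations. It calls forward directly with keyword arguments.

# Illustrative forward pass; token ids and shapes are arbitrary examples.
input_ids = Torch.tensor([[101, 2054, 2003, 102]])   # shape [1, 4], int64
mask = Torch.ones(1, 4)                               # attention mask, same shape

Torch.no_grad do
  out = embeddings.forward(input_ids: input_ids, mask: mask)
  p out.shape                                          # => [1, 4, 768] for hidden_size 768
end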