Class: Transformers::Distilbert::Embeddings

Inherits:
Torch::NN::Module
Defined in:
lib/transformers/models/distilbert/modeling_distilbert.rb

Instance Method Summary

#forward(input_ids, input_embeds) ⇒ Object
#initialize(config) ⇒ Embeddings constructor

Constructor Details

#initialize(config) ⇒ Embeddings

Returns a new instance of Embeddings. The constructor builds the word and position embedding tables, the LayerNorm and dropout layers from config, and registers a non-persistent position_ids buffer covering max_position_embeddings positions.



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 18

def initialize(config)
  super()
  @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.dim, padding_idx: config.pad_token_id)
  @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.dim)

  @LayerNorm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
  @dropout = Torch::NN::Dropout.new(p: config.dropout)
  register_buffer(
    "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
  )
end
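
A minimal construction sketch. The Struct below is a hypothetical stand-in for the real DistilBERT config object and exposes only the attributes this constructor reads; the values are the usual DistilBERT defaults, assumed here for illustration.

require "transformers"

# Hypothetical stand-in config (not the gem's config class); values are
# assumed DistilBERT defaults for vocab_size, dim, pad_token_id,
# max_position_embeddings, and dropout.
config = Struct.new(
  :vocab_size, :dim, :pad_token_id, :max_position_embeddings, :dropout
).new(30522, 768, 0, 512, 0.1)

embeddings = Transformers::Distilbert::Embeddings.new(config)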

Instance Method Details

#forward(input_ids, input_embeds) ⇒ Object
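Returns the embedded representation of the input: word embeddings (looked up from input_ids, or taken from input_embeds when ids are not given) plus position embeddings, followed by LayerNorm and dropout. The result has shape (bs, max_seq_length, dim).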



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 30

def forward(input_ids, input_embeds)
  if !input_ids.nil?
    input_embeds = @word_embeddings.(input_ids)  # (bs, max_seq_length, dim)
  end

  seq_length = input_embeds.size(1)

  # Use the position_ids buffer registered in the constructor; this helps
  # when tracing the model without passing position_ids and avoids
  # issues similar to issue #5664
  if @position_ids
    position_ids = @position_ids[0.., 0...seq_length]
  else
    position_ids = Torch.arange(seq_length, dtype: :long, device: input_ids.device)  # (max_seq_length)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
  end

  position_embeddings = @position_embeddings.(position_ids)  # (bs, max_seq_length, dim)

  embeddings = input_embeds + position_embeddings  # (bs, max_seq_length, dim)
  embeddings = @LayerNorm.(embeddings)  # (bs, max_seq_length, dim)
  embeddings = @dropout.(embeddings)  # (bs, max_seq_length, dim)
  embeddings
end
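
A hedged usage sketch, assuming an embeddings instance built as in the constructor example above; the token ids and the 768 embedding dimension are arbitrary illustrative assumptions.

# Embed a batch of one sequence of four arbitrary token ids.
input_ids = Torch.tensor([[101, 7592, 2088, 102]])   # (1, 4)
output = embeddings.(input_ids, nil)                  # (1, 4, dim)

# Or skip the word-embedding lookup by passing precomputed embeddings.
input_embeds = Torch.randn(1, 4, 768)                 # assumes config.dim == 768
output = embeddings.(nil, input_embeds)               # (1, 4, 768)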