Class: Transformers::Distilbert::Embeddings
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Distilbert::Embeddings
- Defined in: lib/transformers/models/distilbert/modeling_distilbert.rb
Instance Method Summary
- #forward(input_ids, input_embeds) ⇒ Object
- #initialize(config) ⇒ Embeddings (constructor)
  A new instance of Embeddings.
Constructor Details
#initialize(config) ⇒ Embeddings
Returns a new instance of Embeddings.
```ruby
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 18

def initialize(config)
  super()
  @word_embeddings = Torch::NN::Embedding.new(config.vocab_size, config.dim, padding_idx: config.pad_token_id)
  @position_embeddings = Torch::NN::Embedding.new(config.max_position_embeddings, config.dim)

  @LayerNorm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
  @dropout = Torch::NN::Dropout.new(p: config.dropout)
  register_buffer(
    "position_ids", Torch.arange(config.max_position_embeddings).expand([1, -1]), persistent: false
  )
end
```
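A minimal construction sketch, assuming the transformers-rb gem is loaded. The config object below is a hypothetical OpenStruct stand-in (not the library's DistilBertConfig class) that provides only the five attributes the constructor reads; the values shown are DistilBERT's published defaults.

```ruby
require "torch"
require "ostruct"

# Hypothetical stand-in for a DistilBERT config; a real config object
# exposing the same readers would work identically here.
config = OpenStruct.new(
  vocab_size: 30522,            # DistilBERT's default vocabulary size
  dim: 768,                     # hidden dimension
  pad_token_id: 0,              # embedding row reserved for padding
  max_position_embeddings: 512, # longest supported sequence
  dropout: 0.1
)

embeddings = Transformers::Distilbert::Embeddings.new(config)
```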
Instance Method Details
#forward(input_ids, input_embeds) ⇒ Object
```ruby
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 30

def forward(input_ids, input_embeds)
  if !input_ids.nil?
    input_embeds = @word_embeddings.(input_ids) # (bs, max_seq_length, dim)
  end

  seq_length = input_embeds.size(1)

  # Using the position ids buffer registered in the constructor helps
  # when tracing the model without passing position ids, and solves
  # issues similar to issue #5664
  if @position_ids
    position_ids = @position_ids[0.., 0...seq_length]
  else
    position_ids = Torch.arange(seq_length, dtype: :long, device: input_ids.device) # (max_seq_length)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length)
  end

  position_embeddings = @position_embeddings.(position_ids) # (bs, max_seq_length, dim)

  embeddings = input_embeds + position_embeddings # (bs, max_seq_length, dim)
  embeddings = @LayerNorm.(embeddings) # (bs, max_seq_length, dim)
  embeddings = @dropout.(embeddings) # (bs, max_seq_length, dim)
  embeddings
end
```
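The method computes LayerNorm(word_embeddings + position_embeddings) and applies dropout. A minimal sketch of calling it, continuing the constructor example above; the token ids are arbitrary illustrative values, not the output of a real tokenizer, and `.()` is Ruby shorthand for `.call`, which dispatches to #forward.

```ruby
# Two illustrative sequences of length 4 (bs = 2), padded with id 0.
input_ids = Torch.tensor([[101, 2054, 2003, 102], [101, 7592, 102, 0]])

# Passing nil for input_embeds makes forward look the ids up itself.
output = embeddings.(input_ids, nil)
output.shape # => [2, 4, 768], i.e. (bs, max_seq_length, dim)
```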