Class: Transformers::Vit::ViTEmbeddings
- Inherits: Torch::NN::Module
  - Ancestry: Object > Torch::NN::Module > Transformers::Vit::ViTEmbeddings
- Defined in: lib/transformers/models/vit/modeling_vit.rb
Instance Method Summary
- #forward(pixel_values, bool_masked_pos: nil, interpolate_pos_encoding: false) ⇒ Object
- #initialize(config, use_mask_token: false) ⇒ ViTEmbeddings (constructor): Returns a new instance of ViTEmbeddings.
Constructor Details
#initialize(config, use_mask_token: false) ⇒ ViTEmbeddings
Returns a new instance of ViTEmbeddings.
Source: lines 18–28
# File 'lib/transformers/models/vit/modeling_vit.rb', line 18
# Builds the ViT input-embedding layer.
#
# @param config [Object] model configuration; this method reads
#   +hidden_size+ and +hidden_dropout_prob+ from it and also passes it to
#   ViTPatchEmbeddings.new
# @param use_mask_token [Boolean] when true, allocate a learnable mask token
#   (initialized to zeros) that #forward substitutes for masked patches;
#   when false, @mask_token is nil
def initialize(config, use_mask_token: false)
  super()
  hidden = config.hidden_size

  # Learnable [CLS] token, prepended to every patch sequence in #forward.
  @cls_token = Torch::NN::Parameter.new(Torch.randn(1, 1, hidden))

  @mask_token =
    if use_mask_token
      Torch::NN::Parameter.new(Torch.zeros(1, 1, hidden))
    end

  @patch_embeddings = ViTPatchEmbeddings.new(config)

  # One learned position slot per patch, plus one for the [CLS] token.
  @position_embeddings = Torch::NN::Parameter.new(
    Torch.randn(1, @patch_embeddings.num_patches + 1, hidden)
  )

  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  @config = config
end
Instance Method Details
#forward(pixel_values, bool_masked_pos: nil, interpolate_pos_encoding: false) ⇒ Object
Source: lines 30–60
# File 'lib/transformers/models/vit/modeling_vit.rb', line 30
# Embeds a batch of images into the token sequence fed to the ViT encoder.
#
# Pipeline:
# 1. Patch-embed the pixels.
# 2. Optionally replace masked patches with the learned mask token.
# 3. Prepend the [CLS] token.
# 4. Add position embeddings, resized to the input resolution when requested.
# 5. Apply dropout.
#
# @param pixel_values [Torch::Tensor] image batch, destructured below as
#   (batch_size, num_channels, height, width)
# @param bool_masked_pos [Torch::Tensor, nil] per-patch mask; presumably
#   (batch_size, num_patches) with 1/true marking masked patches — TODO confirm
# @param interpolate_pos_encoding [Boolean] resize the learned position
#   embeddings to the input's height/width instead of adding them as-is
# @return [Torch::Tensor] patch embeddings with the [CLS] token prepended
# @raise [ArgumentError] if +bool_masked_pos+ is given but the module was
#   built without +use_mask_token: true+ (so @mask_token is nil)
def forward(
  pixel_values,
  bool_masked_pos: nil,
  interpolate_pos_encoding: false
)
  batch_size, _num_channels, height, width = pixel_values.shape
  embeddings = @patch_embeddings.(pixel_values, interpolate_pos_encoding: interpolate_pos_encoding)

  unless bool_masked_pos.nil?
    # @mask_token is only allocated with use_mask_token: true; fail clearly
    # rather than with a NoMethodError on nil.
    if @mask_token.nil?
      raise ArgumentError,
        "bool_masked_pos was given but this ViTEmbeddings was created without use_mask_token: true"
    end

    seq_length = embeddings.shape[1]
    mask_tokens = @mask_token.expand(batch_size, seq_length, -1)
    # replace the masked visual tokens by mask_tokens
    mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
    embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  end

  # add the [CLS] token to the embedded patch tokens
  cls_tokens = @cls_token.expand(batch_size, -1, -1)
  embeddings = Torch.cat([cls_tokens, embeddings], dim: 1)

  # add positional encoding to each token
  position_embeddings =
    if interpolate_pos_encoding
      interpolate_position_embeddings(embeddings, height, width)
    else
      @position_embeddings
    end
  embeddings += position_embeddings

  @dropout.(embeddings)
end

# Resizes the learned patch position embeddings to match a height x width
# input. The [CLS] slot is kept unchanged and the patch slots are resampled
# with bicubic interpolation.
#
# The learned patch grid is assumed square (side = sqrt(num_positions)), and
# @config.patch_size is assumed to be an Integer or a [height, width] pair,
# as in the reference ViT implementation — TODO confirm against ViTConfig.
#
# @param embeddings [Torch::Tensor] embeddings with [CLS] already prepended;
#   dim 1 is the token count and the last dim is the hidden size
# @param height [Integer] input image height in pixels
# @param width [Integer] input image width in pixels
# @return [Torch::Tensor] position embeddings aligned with +embeddings+
private def interpolate_position_embeddings(embeddings, height, width)
  num_patches = embeddings.shape[1] - 1
  num_positions = @position_embeddings.shape[1] - 1
  # Already the right size: no resampling needed.
  return @position_embeddings if num_patches == num_positions && height == width

  dim = embeddings.shape[-1]
  class_pos_embed = @position_embeddings.narrow(1, 0, 1)
  patch_pos_embed = @position_embeddings.narrow(1, 1, num_positions)

  patch_size = @config.patch_size
  patch_height, patch_width = patch_size.is_a?(Array) ? patch_size : [patch_size, patch_size]
  new_height = height / patch_height
  new_width = width / patch_width

  grid = Integer.sqrt(num_positions)
  # (1, N, dim) -> (1, dim, grid, grid) so interpolate resizes the spatial dims.
  patch_pos_embed = patch_pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
  patch_pos_embed = Torch::NN::Functional.interpolate(
    patch_pos_embed,
    size: [new_height, new_width],
    mode: "bicubic",
    align_corners: false
  )
  # Back to (1, new_height * new_width, dim).
  patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)

  Torch.cat([class_pos_embed, patch_pos_embed], dim: 1)
end