Class: Transformers::Vit::ViTPatchEmbeddings
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Vit::ViTPatchEmbeddings
- Defined in: lib/transformers/models/vit/modeling_vit.rb
Instance Attribute Summary
- #num_patches ⇒ Object (readonly)
  Returns the value of attribute num_patches.
Instance Method Summary
- #forward(pixel_values, interpolate_pos_encoding: false) ⇒ Object
- #initialize(config) ⇒ ViTPatchEmbeddings (constructor)
  A new instance of ViTPatchEmbeddings.
Constructor Details
#initialize(config) ⇒ ViTPatchEmbeddings
Returns a new instance of ViTPatchEmbeddings.
```ruby
# File 'lib/transformers/models/vit/modeling_vit.rb', line 66

def initialize(config)
  super()
  image_size, patch_size = config.image_size, config.patch_size
  num_channels, hidden_size = config.num_channels, config.hidden_size

  image_size = image_size.is_a?(Enumerable) ? image_size : [image_size, image_size]
  patch_size = patch_size.is_a?(Enumerable) ? patch_size : [patch_size, patch_size]
  num_patches = image_size[1].div(patch_size[1]) * image_size[0].div(patch_size[0])
  @image_size = image_size
  @patch_size = patch_size
  @num_channels = num_channels
  @num_patches = num_patches

  @projection = Torch::NN::Conv2d.new(num_channels, hidden_size, patch_size, stride: patch_size)
end
```
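A minimal usage sketch for the constructor. It assumes only that the config object responds to image_size, patch_size, num_channels, and hidden_size, since those are the only readers the constructor touches; the StubConfig struct below is a hypothetical stand-in for a real ViT configuration object.

```ruby
# Assumes the torch-rb and transformers-rb gems are already loaded (e.g. via Bundler).

# Hypothetical stand-in for a ViT configuration object.
StubConfig = Struct.new(:image_size, :patch_size, :num_channels, :hidden_size, keyword_init: true)
config = StubConfig.new(image_size: 224, patch_size: 16, num_channels: 3, hidden_size: 768)

patch_embeddings = Transformers::Vit::ViTPatchEmbeddings.new(config)

# 224 / 16 = 14 patches per side, so 14 * 14 = 196 patches in total.
patch_embeddings.num_patches # => 196
```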
Instance Attribute Details
#num_patches ⇒ Object (readonly)
Returns the value of attribute num_patches.
```ruby
# File 'lib/transformers/models/vit/modeling_vit.rb', line 64

def num_patches
  @num_patches
end
```
Instance Method Details
#forward(pixel_values, interpolate_pos_encoding: false) ⇒ Object
```ruby
# File 'lib/transformers/models/vit/modeling_vit.rb', line 82

def forward(pixel_values, interpolate_pos_encoding: false)
  _batch_size, num_channels, height, width = pixel_values.shape
  if num_channels != @num_channels
    raise ArgumentError,
      "Make sure that the channel dimension of the pixel values match with the one set in the configuration." +
      " Expected #{@num_channels} but got #{num_channels}."
  end
  if !interpolate_pos_encoding
    if height != @image_size[0] || width != @image_size[1]
      raise ArgumentError,
        "Input image size (#{height}*#{width}) doesn't match model" +
        " (#{@image_size[0]}*#{@image_size[1]})."
    end
  end
  embeddings = @projection.(pixel_values).flatten(2).transpose(1, 2)
  embeddings
end
```
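A hedged illustration of the output shape of #forward, reusing the patch_embeddings instance from the constructor sketch above: with a 224x224 RGB input and 16x16 patches, the Conv2d projection followed by flatten and transpose yields 196 patch embeddings of size hidden_size. Calling the module directly dispatches to forward via Torch::NN::Module.

```ruby
# Assumes `patch_embeddings` was built as in the constructor sketch above
# (224x224 images, 16x16 patches, 3 channels, hidden size 768).
pixel_values = Torch.randn(1, 3, 224, 224) # (batch, channels, height, width)

embeddings = patch_embeddings.(pixel_values) # Conv2d projection, then flatten + transpose
embeddings.shape # => [1, 196, 768]  (batch, num_patches, hidden_size)

# A mismatched input size without interpolate_pos_encoding raises ArgumentError:
# patch_embeddings.(Torch.randn(1, 3, 256, 256))
```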