Class: Transformers::Vit::ViTPatchEmbeddings

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/transformers/models/vit/modeling_vit.rb

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(config) ⇒ ViTPatchEmbeddings

Returns a new instance of ViTPatchEmbeddings.



66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# File 'lib/transformers/models/vit/modeling_vit.rb', line 66

# Sets up the patch-embedding layer from a model configuration.
#
# Reads +image_size+, +patch_size+, +num_channels+ and +hidden_size+ from
# +config+. A size that is not an Enumerable is expanded to a two-element
# pair holding the same value twice; an Enumerable size is used as given.
#
# The patch count is the product of the integer quotients (+div+) of the
# image size by the patch size at index 1 and at index 0.
#
# @param config [Object] responds to +image_size+, +patch_size+,
#   +num_channels+ and +hidden_size+
# @return [ViTPatchEmbeddings] a new instance of ViTPatchEmbeddings
def initialize(config)
  super()

  # Scalar sizes become [size, size]; Enumerables pass through untouched.
  to_pair = ->(size) { size.is_a?(Enumerable) ? size : [size, size] }

  image_dims = to_pair.call(config.image_size)
  patch_dims = to_pair.call(config.patch_size)
  channels = config.num_channels
  embed_dim = config.hidden_size

  patches_on_axis1 = image_dims[1].div(patch_dims[1])
  patches_on_axis0 = image_dims[0].div(patch_dims[0])

  @image_size = image_dims
  @patch_size = patch_dims
  @num_channels = channels
  @num_patches = patches_on_axis1 * patches_on_axis0

  # Kernel size and stride are both the patch size.
  @projection = Torch::NN::Conv2d.new(channels, embed_dim, patch_dims, stride: patch_dims)
end

Instance Attribute Details

#num_patches ⇒ Object (readonly)

Returns the value of attribute num_patches.



64
65
66
# File 'lib/transformers/models/vit/modeling_vit.rb', line 64

# Read-only accessor: returns the current value of the @num_patches
# instance variable.
def num_patches
  @num_patches
end

Instance Method Details

#forward(pixel_values, interpolate_pos_encoding: false) ⇒ Object



82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# File 'lib/transformers/models/vit/modeling_vit.rb', line 82

# Runs the patch projection over +pixel_values+ and reshapes the result.
#
# The input's +shape+ is destructured into four values, read here as
# (batch, channels, height, width). The channel count must equal the one
# stored at construction time. Unless +interpolate_pos_encoding+ is truthy,
# height and width must also equal @image_size[0] and @image_size[1].
#
# @param pixel_values [Object] tensor-like input that responds to +shape+
#   and is accepted by the projection layer
# @param interpolate_pos_encoding [Boolean] when truthy, the image-size
#   check is skipped
# @return [Object] the projection output after +flatten(2)+ and
#   +transpose(1, 2)+
# @raise [ArgumentError] if the channel count differs from @num_channels, or
#   if the size check is active and height/width differ from @image_size
def forward(pixel_values, interpolate_pos_encoding: false)
  _batch, channels, rows, cols = pixel_values.shape

  unless channels == @num_channels
    raise ArgumentError,
      "Make sure that the channel dimension of the pixel values match with the one set in the configuration. " \
      "Expected #{@num_channels} but got #{channels}."
  end

  # Short-circuits so @image_size is not consulted when the check is disabled.
  unless interpolate_pos_encoding || (rows == @image_size[0] && cols == @image_size[1])
    raise ArgumentError,
      "Input image size (#{rows}*#{cols}) doesn't match model " \
      "(#{@image_size[0]}*#{@image_size[1]})."
  end

  projected = @projection.call(pixel_values)
  projected.flatten(2).transpose(1, 2)
end