Class: Transformers::Vit::ViTEncoder
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Vit::ViTEncoder
- Defined in: lib/transformers/models/vit/modeling_vit.rb
Instance Method Summary
- #forward(hidden_states, head_mask: nil, output_attentions: false, output_hidden_states: false, return_dict: true) ⇒ Object
- #initialize(config) ⇒ ViTEncoder (constructor)
  A new instance of ViTEncoder.
Constructor Details
#initialize(config) ⇒ ViTEncoder
Returns a new instance of ViTEncoder.
# File 'lib/transformers/models/vit/modeling_vit.rb', line 288

def initialize(config)
  super()
  @config = config
  @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { ViTLayer.new(config) })
  @gradient_checkpointing = false
end
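The constructor stores the config and builds one ViTLayer per config.num_hidden_layers inside a Torch::NN::ModuleList, with gradient checkpointing disabled. A minimal construction sketch; directly instantiating Transformers::Vit::ViTConfig with ViT-Base style defaults is an assumption here, since the encoder is normally created internally by ViTModel rather than by users:

  require "transformers"

  # Assumption: ViTConfig exposes defaults mirroring the Python library (e.g. 12 hidden layers).
  config  = Transformers::Vit::ViTConfig.new
  encoder = Transformers::Vit::ViTEncoder.new(config)

  # The encoder holds one ViTLayer per configured hidden layer.
  encoder.instance_variable_get(:@layer).length # => config.num_hidden_layers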
Instance Method Details
#forward(hidden_states, head_mask: nil, output_attentions: false, output_hidden_states: false, return_dict: true) ⇒ Object
# File 'lib/transformers/models/vit/modeling_vit.rb', line 295

def forward(
  hidden_states,
  head_mask: nil,
  output_attentions: false,
  output_hidden_states: false,
  return_dict: true
)
  # Accumulators are only allocated when the caller asks for them.
  all_hidden_states = output_hidden_states ? [] : nil
  all_self_attentions = output_attentions ? [] : nil

  @layer.each_with_index do |layer_module, i|
    # Record the hidden states as they enter each layer.
    if output_hidden_states
      all_hidden_states = all_hidden_states + [hidden_states]
    end

    # Per-layer head mask, if a mask was provided.
    layer_head_mask = !head_mask.nil? ? head_mask[i] : nil

    if @gradient_checkpointing && @training
      raise Todo
    else
      layer_outputs = layer_module.(hidden_states, head_mask: layer_head_mask, output_attentions: output_attentions)
    end

    hidden_states = layer_outputs[0]

    if output_attentions
      all_self_attentions = all_self_attentions + [layer_outputs[1]]
    end
  end

  # Include the final hidden states after the last layer.
  if output_hidden_states
    all_hidden_states = all_hidden_states + [hidden_states]
  end

  if !return_dict
    raise Todo
  end

  BaseModelOutput.new(
    last_hidden_state: hidden_states,
    hidden_states: all_hidden_states,
    attentions: all_self_attentions
  )
end
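forward runs hidden_states through each ViTLayer in sequence, optionally collecting the hidden states before each layer (plus once after the last) and the per-layer attention weights, then wraps the results in a BaseModelOutput. A usage sketch with random patch embeddings of shape [batch, sequence, hidden_size] (1 x 197 x 768 for ViT-Base: 196 patches plus the CLS token); constructing ViTConfig directly is the same assumption as above:

  require "transformers"

  config  = Transformers::Vit::ViTConfig.new              # assumption: ViT-Base defaults
  encoder = Transformers::Vit::ViTEncoder.new(config)

  # Stand-in for the output of ViTEmbeddings: [batch, patches + CLS, hidden_size].
  hidden_states = Torch.randn(1, 197, config.hidden_size)

  output = encoder.(hidden_states, output_hidden_states: true, output_attentions: true)

  output.last_hidden_state.shape  # => [1, 197, 768]
  output.hidden_states.length     # => config.num_hidden_layers + 1 (input + after each layer)
  output.attentions.length        # => config.num_hidden_layers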