Class: Transformers::Vit::ViTForImageClassification
- Inherits: ViTPreTrainedModel
- Inheritance chain: Object > Torch::NN::Module > PreTrainedModel > ViTPreTrainedModel > Transformers::Vit::ViTForImageClassification
- Defined in: lib/transformers/models/vit/modeling_vit.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(pixel_values: nil, head_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, interpolate_pos_encoding: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ ViTForImageClassification (constructor): A new instance of ViTForImageClassification.
Methods inherited from ViTPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ ViTForImageClassification
Returns a new instance of ViTForImageClassification.
# File 'lib/transformers/models/vit/modeling_vit.rb', line 449

def initialize(config)
  super(config)

  @num_labels = config.num_labels
  @vit = ViTModel.new(config, add_pooling_layer: false)

  # Classifier head
  @classifier = config.num_labels > 0 ? Torch::NN::Linear.new(config.hidden_size, config.num_labels) : Torch::NN::Identity.new

  # Initialize weights and apply final processing
  post_init
end
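The constructor chooses the classifier head from config.num_labels: a linear projection from hidden_size to num_labels when there are labels to predict, and a pass-through otherwise. The following is a minimal pure-Ruby sketch of that selection only. LinearHead, IdentityHead, and build_head are hypothetical stand-ins introduced for illustration; they are not the Torch::NN classes and carry no weights.

```ruby
# Hypothetical stand-ins (NOT Torch::NN::Linear / Torch::NN::Identity).
# They model only the width of the output each head would produce.
LinearHead = Struct.new(:in_features, :out_features) do
  def output_size
    out_features
  end
end

IdentityHead = Struct.new(:in_features) do
  def output_size
    in_features
  end
end

# Mirrors the ternary in #initialize: project to num_labels when it is
# positive, otherwise leave the hidden representation untouched.
def build_head(hidden_size, num_labels)
  num_labels > 0 ? LinearHead.new(hidden_size, num_labels) : IdentityHead.new(hidden_size)
end

p build_head(768, 1000).output_size # 1000
p build_head(768, 0).output_size    # 768
```

With num_labels set to 0, the value returned as logits would therefore have the hidden size as its last dimension rather than a label count.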
Instance Method Details
#forward(pixel_values: nil, head_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, interpolate_pos_encoding: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/vit/modeling_vit.rb', line 462

def forward(
  pixel_values: nil,
  head_mask: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  interpolate_pos_encoding: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @vit.(
    pixel_values: pixel_values,
    head_mask: head_mask,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    interpolate_pos_encoding: interpolate_pos_encoding,
    return_dict: return_dict
  )

  sequence_output = outputs[0]

  logits = @classifier.(sequence_output[0.., 0, 0..])

  loss = nil
  if !labels.nil?
    raise Todo
  end

  if !return_dict
    raise Todo
  end

  ImageClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
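Two steps in #forward are easy to misread: the nil-aware default for return_dict, and the slice sequence_output[0.., 0, 0..], which keeps the first token (in ViT, the class token) of every batch item before the classifier is applied. The sketch below reproduces both with plain Ruby values, assuming a nested array shaped [batch, tokens, hidden] in place of a tensor. resolve_flag and first_token are hypothetical helper names, not part of the library.

```ruby
# Resolve an optional flag the way #forward does: an explicit true or false
# wins, and only nil falls back to the configured default. Using `||` here
# would wrongly discard an explicit false.
def resolve_flag(explicit, default)
  explicit.nil? ? default : explicit
end

# Nested-array analogue of sequence_output[0.., 0, 0..]:
# all batch items, token index 0, all hidden features.
def first_token(sequence_output)
  sequence_output.map { |tokens| tokens.first }
end

batch = [
  [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], # item 0: 3 tokens, hidden size 2
  [[1.1, 1.2], [1.3, 1.4], [1.5, 1.6]]  # item 1
]

p first_token(batch)        # [[0.1, 0.2], [1.1, 1.2]]
p resolve_flag(false, true) # false
p resolve_flag(nil, true)   # true
```

Note that, per the listing above, passing labels or resolving return_dict to false currently reaches a raise Todo branch, so only the logits path with a structured output is exercised.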