Class: Transformers::Vit::ViTForImageClassification

Inherits:
ViTPreTrainedModel
Defined in:
lib/transformers/models/vit/modeling_vit.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from ViTPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ ViTForImageClassification

Returns a new instance of ViTForImageClassification.



449
450
451
452
453
454
455
456
457
458
459
460
# File 'lib/transformers/models/vit/modeling_vit.rb', line 449

# Builds the ViT backbone and the classification head on top of it.
#
# @param config [Object] model configuration; this constructor reads
#   `num_labels` and `hidden_size` from it and forwards it to the
#   superclass and to ViTModel
def initialize(config)
  super

  @num_labels = config.num_labels

  # Backbone is created with its pooling layer disabled
  @vit = ViTModel.new(config, add_pooling_layer: false)

  # Head: a linear projection to `num_labels` outputs when there is at least
  # one label, otherwise a pass-through identity module
  @classifier =
    if config.num_labels > 0
      Torch::NN::Linear.new(config.hidden_size, config.num_labels)
    else
      Torch::NN::Identity.new
    end

  # Weight initialization and any final setup are delegated to post_init
  post_init
end

Instance Method Details

#forward(pixel_values: nil, head_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, interpolate_pos_encoding: nil, return_dict: nil) ⇒ Object



462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
# File 'lib/transformers/models/vit/modeling_vit.rb', line 462

# Runs the ViT backbone on the input and feeds a slice of its first output
# through the classifier head.
#
# Loss computation (`labels` given) and non-dict return (`return_dict` false)
# are not implemented: both raise Todo, and only after the backbone and the
# classifier have already been evaluated.
#
# @param pixel_values forwarded unchanged to the backbone
# @param head_mask forwarded unchanged to the backbone
# @param labels must be nil; any other value raises Todo
# @param output_attentions forwarded unchanged to the backbone
# @param output_hidden_states forwarded unchanged to the backbone
# @param interpolate_pos_encoding forwarded unchanged to the backbone
# @param return_dict when nil, `@config.use_return_dict` is used instead
# @return [ImageClassifierOutput] with `loss` set to nil, the classifier
#   `logits`, and the backbone's `hidden_states` and `attentions`
# @raise [Todo] if `labels` is not nil, or if the resolved `return_dict` is falsy
def forward(
  pixel_values: nil,
  head_mask: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  interpolate_pos_encoding: nil,
  return_dict: nil
)
  # Only fall back to the config default when the caller left it unset;
  # an explicit false must be kept as-is
  return_dict = @config.use_return_dict if return_dict.nil?

  outputs = @vit.call(
    pixel_values: pixel_values,
    head_mask: head_mask,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    interpolate_pos_encoding: interpolate_pos_encoding,
    return_dict: return_dict
  )

  # NOTE(review): presumably outputs[0] is (batch, sequence, hidden) and this
  # picks the first ([CLS]) token of every sample - confirm against ViTModel
  first_token_state = outputs[0][0.., 0, 0..]
  logits = @classifier.call(first_token_state)

  # Loss computation is not implemented yet
  raise Todo unless labels.nil?

  # Tuple-style (non-dict) return is not implemented yet
  raise Todo unless return_dict

  ImageClassifierOutput.new(
    loss: nil,
    logits: logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end