Class: Transformers::Vit::ViTModel

Inherits:
ViTPreTrainedModel
Defined in:
lib/transformers/models/vit/modeling_vit.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from ViTPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config, add_pooling_layer: true, use_mask_token: false) ⇒ ViTModel

Returns a new instance of ViTModel.



# File 'lib/transformers/models/vit/modeling_vit.rb', line 351

def initialize(config, add_pooling_layer: true, use_mask_token: false)
  super(config)
  @config = config

  @embeddings = ViTEmbeddings.new(config, use_mask_token: use_mask_token)
  @encoder = ViTEncoder.new(config)

  @layernorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  @pooler = add_pooling_layer ? ViTPooler.new(config) : nil

  # Initialize weights and apply final processing
  post_init
end
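
A minimal usage sketch: the model is normally obtained through the inherited from_pretrained class method rather than constructed directly. The config class name and checkpoint name below are illustrative assumptions, not taken from this page.

# Build from a config (assumes a ViTConfig class defined alongside this model).
config = Transformers::Vit::ViTConfig.new
model = Transformers::Vit::ViTModel.new(config, add_pooling_layer: true)

# Or load pretrained weights via the inherited class method
# (checkpoint name is illustrative).
model = Transformers::Vit::ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")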

Instance Method Details

#_prune_heads(heads_to_prune) ⇒ Object



# File 'lib/transformers/models/vit/modeling_vit.rb', line 365

# Prunes attention heads of the model.
# heads_to_prune: hash of { layer_index => array of head indices to prune in that layer }
def _prune_heads(heads_to_prune)
  heads_to_prune.each do |layer, heads|
    @encoder.layer[layer].attention.prune_heads(heads)
  end
end
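
heads_to_prune is a hash mapping a layer index to the attention heads to prune in that layer, mirroring the inherited PreTrainedModel#prune_heads contract. A minimal sketch (layer and head indices are illustrative):

# Prune heads 0 and 2 in layer 1, and head 4 in layer 3.
model._prune_heads({1 => [0, 2], 3 => [4]})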

#forward(pixel_values: nil, bool_masked_pos: nil, head_mask: nil, output_attentions: nil, output_hidden_states: nil, interpolate_pos_encoding: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/vit/modeling_vit.rb', line 371

def forward(
  pixel_values: nil,
  bool_masked_pos: nil,
  head_mask: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  interpolate_pos_encoding: nil,
  return_dict: nil
)
  output_attentions = !output_attentions.nil? ? output_attentions : @config.output_attentions
  output_hidden_states = (
    !output_hidden_states.nil? ? output_hidden_states : @config.output_hidden_states
  )
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  if pixel_values.nil?
    raise ArgumentError, "You have to specify pixel_values"
  end

  # Prepare head mask if needed
  # 1.0 in head_mask indicates we keep the head
  # attention_probs has shape bsz x n_heads x N x N
  # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  head_mask = get_head_mask(head_mask, @config.num_hidden_layers)

  # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
  expected_dtype = @embeddings.patch_embeddings.projection.weight.dtype
  if pixel_values.dtype != expected_dtype
    pixel_values = pixel_values.to(expected_dtype)
  end

  embedding_output = @embeddings.(
    pixel_values, bool_masked_pos: bool_masked_pos, interpolate_pos_encoding: interpolate_pos_encoding
  )

  encoder_outputs = @encoder.(
    embedding_output,
    head_mask: head_mask,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  sequence_output = encoder_outputs[0]
  sequence_output = @layernorm.(sequence_output)
  pooled_output = @pooler ? @pooler.(sequence_output) : nil

  # Non-dict (tuple) output is not implemented in this port
  if !return_dict
    raise Todo
  end

  BaseModelOutputWithPooling.new(
    last_hidden_state: sequence_output,
    pooler_output: pooled_output,
    hidden_states: encoder_outputs.hidden_states,
    attentions: encoder_outputs.attentions
  )
end
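
A minimal usage sketch, assuming torch-rb and a model loaded as in the constructor example above. The shapes shown are an assumption for the ViT-Base configuration (224x224 input, 16x16 patches, hidden size 768), where the sequence length is 196 patches plus one [CLS] token:

# Random batch of one RGB image at ViT-Base resolution (illustrative input).
pixel_values = Torch.randn(1, 3, 224, 224)

outputs = model.forward(pixel_values: pixel_values)
outputs.last_hidden_state.shape # => [1, 197, 768]
outputs.pooler_output.shape     # => [1, 768] (pooler_output is nil if add_pooling_layer: false)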