Class: Transformers::Vit::ViTAttention

Inherits:
Torch::NN::Module (ancestor chain: ViTAttention < Torch::NN::Module < Object)
Defined in:
lib/transformers/models/vit/modeling_vit.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ ViTAttention

Returns a new instance of ViTAttention.



Source lines 181–186:
# File 'lib/transformers/models/vit/modeling_vit.rb', line 181

# Builds the attention sub-block: a self-attention module followed by its
# output module, both constructed from the same configuration.
#
# @param config model configuration, passed unchanged to ViTSelfAttention.new
#   and ViTSelfOutput.new (the attributes they read are defined in those
#   classes, not here)
def initialize(config)
  super()
  @attention = ViTSelfAttention.new(config)
  @output = ViTSelfOutput.new(config)
  # Starts empty; presumably meant to record pruned head indices once
  # prune_heads is implemented — NOTE(review): confirm.
  @pruned_heads = Set[]
end

Instance Method Details

#forward(hidden_states, head_mask: nil, output_attentions: false) ⇒ Object



Source lines 192–203:
# File 'lib/transformers/models/vit/modeling_vit.rb', line 192

# Runs self-attention on +hidden_states+, then passes the first result,
# together with the original +hidden_states+, through the output module.
#
# @param hidden_states input to the attention layer (presumably a tensor of
#   shape (batch, seq_len, hidden_size) — TODO confirm; not checked here)
# @param head_mask [Object, nil] forwarded untouched to the self-attention module
# @param output_attentions [Boolean] forwarded to the self-attention module
# @return [Array] the output-module result, followed by every element the
#   self-attention module returned after its first one
def forward(
  hidden_states,
  head_mask: nil,
  output_attentions: false
)
  attn_results = @attention.call(
    hidden_states,
    head_mask: head_mask,
    output_attentions: output_attentions
  )
  projected = @output.call(attn_results[0], hidden_states)

  # Extra self-attention outputs (the attention weights, when requested) are
  # carried through unchanged after the projected output.
  [projected] + attn_results[1..]
end

#prune_heads(heads) ⇒ Object

Raises: Todo (always; head pruning is not implemented yet)



Source lines 188–190:
# File 'lib/transformers/models/vit/modeling_vit.rb', line 188

# Unimplemented stub: attention-head pruning is not supported yet.
#
# @param _heads the heads that would be pruned (currently ignored)
# @raise [Todo] on every call
def prune_heads(_heads)
  raise Todo
end