Class: Transformers::XlmRoberta::XLMRobertaAttention

Inherits:
  Torch::NN::Module
    • Object
Defined in:
lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb

Instance Method Summary

  #initialize(config, position_embedding_type: nil) ⇒ XLMRobertaAttention (constructor)
  #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
  #prune_heads(heads) ⇒ Object

Constructor Details

#initialize(config, position_embedding_type: nil) ⇒ XLMRobertaAttention

Returns a new instance of XLMRobertaAttention.



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 343

def initialize(config, position_embedding_type: nil)
  super()
  @self = XLM_ROBERTA_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config, position_embedding_type: position_embedding_type)
  @output = XLMRobertaSelfOutput.new(config)
  @pruned_heads = Set.new
end
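
A minimal construction sketch, not taken from the gem's own docs: it assumes the gem is loaded with require "transformers", that Transformers::XlmRoberta::XLMRobertaConfig accepts the keyword arguments shown, and that the config's default _attn_implementation resolves to a key present in XLM_ROBERTA_SELF_ATTENTION_CLASSES (typically "eager"). The sizes are illustrative.

require "transformers"

# Small illustrative config; XLMRobertaConfig and these keywords are assumptions.
# hidden_size must be divisible by num_attention_heads.
config = Transformers::XlmRoberta::XLMRobertaConfig.new(
  hidden_size: 64,
  num_attention_heads: 4,
  intermediate_size: 128
)

# The constructor looks up the self-attention implementation from
# config._attn_implementation and wires in the XLMRobertaSelfOutput projection.
attention = Transformers::XlmRoberta::XLMRobertaAttention.new(config)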

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 368

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  self_outputs = @self.(hidden_states, attention_mask:, head_mask:, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
  attention_output = @output.(self_outputs[0], hidden_states)
  outputs = [attention_output] + self_outputs[1..]
  outputs
end
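
A hedged call sketch for the forward pass, reusing the config and attention objects from the constructor example above; Torch.randn and the [batch, seq_len, hidden_size] input shape follow the usual torch.rb conventions, and the .() call syntax mirrors how this method invokes its own submodules.

# Two sequences of length 8 with the model's hidden size.
hidden_states = Torch.randn(2, 8, config.hidden_size)

outputs = attention.(hidden_states, output_attentions: true)

attention_output = outputs[0] # projected attention output, same shape as hidden_states
attention_probs  = outputs[1] # attention weights, returned when output_attentions: true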

#prune_heads(heads) ⇒ Object



# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 350

def prune_heads(heads)
  if heads.length == 0
    return
  end
  heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, @self.num_attention_heads, @self.attention_head_size, @pruned_heads)

  # Prune linear layers
  @query = TorchUtils.prune_linear_layer(@self.query, index)
  @key = TorchUtils.prune_linear_layer(@self.key, index)
  @value = TorchUtils.prune_linear_layer(@self.value, index)
  @dense = TorchUtils.prune_linear_layer(@output.dense, index, dim: 1)

  # Update hyper params and store pruned heads
  @num_attention_heads = @self.num_attention_heads - heads.length
  @all_head_size = @self.attention_head_size * @self.num_attention_heads
  @pruned_heads = @pruned_heads.union(heads)
end
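
A hedged usage sketch for head pruning, again reusing the attention instance from the constructor example; the head indices are illustrative. Indices must identify valid heads of the wrapped self-attention layer (0 up to num_attention_heads - 1), and heads already recorded in @pruned_heads are excluded from subsequent calls.

# Prune attention heads 0 and 2.
attention.prune_heads([0, 2])

# No-op: an empty list returns immediately.
attention.prune_heads([])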