Class: Transformers::XlmRoberta::XLMRobertaAttention
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::XlmRoberta::XLMRobertaAttention
- Defined in:
- lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Method Summary
- #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
- #initialize(config, position_embedding_type: nil) ⇒ XLMRobertaAttention (constructor)
  A new instance of XLMRobertaAttention.
- #prune_heads(heads) ⇒ Object
Constructor Details
#initialize(config, position_embedding_type: nil) ⇒ XLMRobertaAttention
Returns a new instance of XLMRobertaAttention.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 343

def initialize(config, position_embedding_type: nil)
  super()
  @self = XLM_ROBERTA_SELF_ATTENTION_CLASSES.fetch(config._attn_implementation).new(config, position_embedding_type:)
  @output = XLMRobertaSelfOutput.new(config)
  @pruned_heads = Set.new
end
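A minimal construction sketch, assuming the transformers-rb gem is loaded and that XLMRobertaConfig accepts keyword options mirroring the Python configuration class (the option names and default attention implementation are assumptions, not documented on this page):

# Hypothetical sketch: build a standalone attention block from a config.
# XLMRobertaConfig and its keyword options are assumed to mirror the Python library.
config = Transformers::XlmRoberta::XLMRobertaConfig.new(
  hidden_size: 768,
  num_attention_heads: 12
)
attention = Transformers::XlmRoberta::XLMRobertaAttention.new(
  config,
  position_embedding_type: "absolute"
)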
Instance Method Details
#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 368

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  self_outputs = @self.(hidden_states, attention_mask:, head_mask:, encoder_hidden_states:, encoder_attention_mask:, past_key_value:, output_attentions:)
  attention_output = @output.(self_outputs[0], hidden_states)
  outputs = [attention_output] + self_outputs[1..]
  outputs
end
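A hedged call sketch, reusing the attention instance built above. The tensor shapes and the additive attention-mask convention (0.0 keeps a position, a large negative value masks it) follow the Python XLM-RoBERTa implementation and are assumptions, not documented on this page. The method returns an array whose first element is the attention output; attention probabilities follow when output_attentions is true.

# Hypothetical sketch: run self-attention over one sequence of length 8.
hidden_states = Torch.randn(1, 8, 768)     # [batch, seq_len, hidden_size]
attention_mask = Torch.zeros(1, 1, 1, 8)   # additive mask: all positions kept
outputs = attention.forward(
  hidden_states,
  attention_mask: attention_mask,
  output_attentions: true
)
attention_output = outputs[0]   # same shape as hidden_states
attention_probs = outputs[1]    # present because output_attentions: true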
#prune_heads(heads) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 350

def prune_heads(heads)
  if heads.length == 0
    return
  end
  heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, @self.num_attention_heads, @self.attention_head_size, @pruned_heads)

  # Prune linear layers
  @query = TorchUtils.prune_linear_layer(@self.query, index)
  @key = TorchUtils.prune_linear_layer(@self.key, index)
  @value = TorchUtils.prune_linear_layer(@self.value, index)
  @dense = TorchUtils.prune_linear_layer(@output.dense, index, dim: 1)

  # Update hyper params and store pruned heads
  @num_attention_heads = @self.num_attention_heads - heads.length
  @all_head_size = @self.attention_head_size * @self.num_attention_heads
  @pruned_heads = @pruned_heads.union(heads)
end