Class: Transformers::Mpnet::MPNetAttention
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Mpnet::MPNetAttention
- Defined in: lib/transformers/models/mpnet/modeling_mpnet.rb
Instance Method Summary
- #forward(hidden_states, attention_mask: nil, head_mask: nil, position_bias: nil, output_attentions: false, **kwargs) ⇒ Object
- #initialize(config) ⇒ MPNetAttention (constructor)
  A new instance of MPNetAttention.
- #prune_heads(heads) ⇒ Object
Constructor Details
#initialize(config) ⇒ MPNetAttention
Returns a new instance of MPNetAttention.
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 179

def initialize(config)
  super()
  @attn = MPNetSelfAttention.new(config)
  @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  @pruned_heads = Set.new
end
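The constructor itself only reads hidden_size, layer_norm_eps, and hidden_dropout_prob from config; the wrapped MPNetSelfAttention reads the remaining attention fields. Below is a minimal instantiation sketch using a hypothetical Struct as a stand-in config. The extra attributes (num_attention_heads, attention_probs_dropout_prob) are assumed from the MPNet self-attention layer and are not shown in this section; in practice you would pass the model's own MPNet config object instead.

require "torch"
# Assumes the transformers-rb gem is already loaded (e.g. via Bundler).

# Hypothetical stand-in exposing only the config attributes the attention
# layers are expected to read; not part of the library.
StubConfig = Struct.new(
  :hidden_size, :num_attention_heads, :attention_probs_dropout_prob,
  :hidden_dropout_prob, :layer_norm_eps,
  keyword_init: true
)

config = StubConfig.new(
  hidden_size: 768,
  num_attention_heads: 12,
  attention_probs_dropout_prob: 0.1,
  hidden_dropout_prob: 0.1,
  layer_norm_eps: 1e-12
)

attention = Transformers::Mpnet::MPNetAttention.new(config)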
Instance Method Details
#forward(hidden_states, attention_mask: nil, head_mask: nil, position_bias: nil, output_attentions: false, **kwargs) ⇒ Object
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 204

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  position_bias: nil,
  output_attentions: false,
  **kwargs
)
  self_outputs = @attn.(hidden_states, attention_mask: attention_mask, head_mask: head_mask, position_bias: position_bias, output_attentions: output_attentions)
  attention_output = @LayerNorm.(@dropout.(self_outputs[0]) + hidden_states)
  outputs = [attention_output] + self_outputs[1..]
  outputs
end
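forward runs the wrapped self-attention, then applies dropout and a residual LayerNorm: LayerNorm(dropout(self_outputs[0]) + hidden_states). A minimal usage sketch, assuming the attention instance from the constructor example above and that, as in the Python MPNet implementation, the attention probabilities come back as the second element of self_outputs when output_attentions is true:

batch_size = 2
seq_len = 8
hidden_size = 768

hidden_states = Torch.randn(batch_size, seq_len, hidden_size)

# Torch::NN::Module#call dispatches to #forward, matching the @attn.(...)
# style used in the source above.
outputs = attention.(hidden_states, output_attentions: true)

attention_output = outputs[0] # shape: [batch_size, seq_len, hidden_size]
attention_probs = outputs[1]  # present only because output_attentions: true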
#prune_heads(heads) ⇒ Object
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 188

def prune_heads(heads)
  if heads.length == 0
    return
  end
  heads, index = TorchUtils.find_pruneable_heads_and_indices(heads, @attn.num_attention_heads, @attn.attention_head_size, @pruned_heads)

  @q = TorchUtils.prune_linear_layer(@attn.q, index)
  @k = TorchUtils.prune_linear_layer(@attn.k, index)
  @v = TorchUtils.prune_linear_layer(@attn.v, index)
  @o = TorchUtils.prune_linear_layer(@attn.o, index, dim: 1)

  @num_attention_heads = @attn.num_attention_heads - heads.length
  @all_head_size = @attn.attention_head_size * @attn.num_attention_heads
  @pruned_heads = @pruned_heads.union(heads)
end
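prune_heads maps the given head indices to flat row/column indices with TorchUtils.find_pruneable_heads_and_indices, shrinks the q/k/v/o projections with TorchUtils.prune_linear_layer, and records the heads in @pruned_heads so they are filtered out on later calls. A hedged usage sketch, assuming the 12-head, 768-dimensional configuration from the constructor example above (so attention_head_size is 64):

# Prune the first two attention heads: the q/k/v projections lose the
# corresponding 2 * 64 = 128 output rows, and the output projection o
# loses the matching 128 input columns (dim: 1).
attention.prune_heads([0, 1])

# An empty list returns immediately without touching the layer.
attention.prune_heads([])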