Class: Transformers::Mpnet::MPNetAttention

Inherits: Torch::NN::Module

Defined in: lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Method Summary

- #initialize(config) ⇒ MPNetAttention (constructor)
- #forward(hidden_states, attention_mask: nil, head_mask: nil, position_bias: nil, output_attentions: false, **kwargs) ⇒ Object
- #prune_heads(heads) ⇒ Object

Constructor Details

#initialize(config) ⇒ MPNetAttention

Returns a new instance of MPNetAttention.



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 179

def initialize(config)
  super()
  @attn = MPNetSelfAttention.new(config)
  # Post-attention LayerNorm and dropout, applied around the residual connection in #forward
  @LayerNorm = Torch::NN::LayerNorm.new(config.hidden_size, eps: config.layer_norm_eps)
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)

  @pruned_heads = Set.new # indices of heads removed via #prune_heads
end
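The constructor wires together the pieces #forward uses: the self-attention layer, a LayerNorm over hidden_size, and dropout applied before the residual add. A minimal construction sketch, assuming an MPNetConfig in this gem that exposes the usual MPNet defaults (hidden_size, layer_norm_eps, hidden_dropout_prob):

# Assumes the transformers-rb gem is installed; the exact require may vary with your setup.
require "transformers-rb"

config = Transformers::Mpnet::MPNetConfig.new
attention = Transformers::Mpnet::MPNetAttention.new(config)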

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, position_bias: nil, output_attentions: false, **kwargs) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 204

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  position_bias: nil,
  output_attentions: false,
  **kwargs
)
  self_outputs = @attn.(
    hidden_states,
    attention_mask: attention_mask,
    head_mask: head_mask,
    position_bias: position_bias,
    output_attentions: output_attentions
  )
  # Residual connection: dropout the attention output, add the input, then normalize
  attention_output = @LayerNorm.(@dropout.(self_outputs[0]) + hidden_states)
  # Pass through any extra outputs (e.g. attention probabilities when output_attentions is true)
  [attention_output] + self_outputs[1..]
end
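A hedged usage sketch, with attention and config as constructed above; the tensor shapes are illustrative. Since Torch::NN::Module instances are callable, the module is invoked with .() just as @attn and @LayerNorm are in the source:

hidden_states = Torch.randn(2, 8, config.hidden_size) # [batch, seq_len, hidden_size]

outputs = attention.(hidden_states, output_attentions: true)
attention_output = outputs[0] # [2, 8, hidden_size]: LayerNorm(dropout(attn) + input)
attention_probs = outputs[1]  # per-head attention weights, present since output_attentions: true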

#prune_heads(heads) ⇒ Object



# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 188

def prune_heads(heads)
  return if heads.length == 0

  # Filter out heads that were already pruned and build an index of the dimensions to keep
  heads, index = TorchUtils.find_pruneable_heads_and_indices(
    heads, @attn.num_attention_heads, @attn.attention_head_size, @pruned_heads
  )

  # Shrink q/k/v along their output dimension and the output projection along its input dimension
  @q = TorchUtils.prune_linear_layer(@attn.q, index)
  @k = TorchUtils.prune_linear_layer(@attn.k, index)
  @v = TorchUtils.prune_linear_layer(@attn.v, index)
  @o = TorchUtils.prune_linear_layer(@attn.o, index, dim: 1)

  # Update head bookkeeping and record which heads have been pruned
  @num_attention_heads = @attn.num_attention_heads - heads.length
  @all_head_size = @attn.attention_head_size * @attn.num_attention_heads
  @pruned_heads = @pruned_heads.union(heads)
end