Class: Transformers::Bert::BertLayer
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Bert::BertLayer
- Defined in: lib/transformers/models/bert/modeling_bert.rb
Instance Method Summary
- #feed_forward_chunk(attention_output) ⇒ Object
- #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
- #initialize(config) ⇒ BertLayer (constructor)
  A new instance of BertLayer.
Constructor Details
#initialize(config) ⇒ BertLayer
Returns a new instance of BertLayer.
# File 'lib/transformers/models/bert/modeling_bert.rb', line 290

def initialize(config)
  super()
  @chunk_size_feed_forward = config.chunk_size_feed_forward
  @seq_len_dim = 1
  @attention = BertAttention.new(config)
  @is_decoder = config.is_decoder
  @add_cross_attention = config.add_cross_attention
  if @add_cross_attention
    if !@is_decoder
      raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
    end
    @crossattention = BertAttention.new(config, position_embedding_type: "absolute")
  end
  @intermediate = BertIntermediate.new(config)
  @output = BertOutput.new(config)
end
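In normal use a BertLayer is created for you by the encoder stack, but it can also be constructed directly from a config. The snippet below is a minimal sketch; it assumes Transformers::Bert::BertConfig exposes the usual BERT defaults (hidden_size, chunk_size_feed_forward, is_decoder, add_cross_attention, and so on) and is not part of this file.

require "transformers"

# Minimal construction sketch (assumes BertConfig defaults mirror the Python library).
config = Transformers::Bert::BertConfig.new
layer = Transformers::Bert::BertLayer.new(config)

# The constructor wires up @attention, @intermediate and @output;
# @crossattention is only created when config.add_cross_attention is true,
# which in turn requires config.is_decoder to be true.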
Instance Method Details
#feed_forward_chunk(attention_output) ⇒ Object
# File 'lib/transformers/models/bert/modeling_bert.rb', line 353

def feed_forward_chunk(attention_output)
  intermediate_output = @intermediate.(attention_output)
  layer_output = @output.(intermediate_output, attention_output)
  layer_output
end
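#feed_forward_chunk is not meant to be called directly: #forward hands it to TorchUtils.apply_chunking_to_forward, which, when @chunk_size_feed_forward is positive, slices the input along @seq_len_dim, runs the intermediate/output blocks on each slice, and concatenates the results to reduce peak memory. The following is a conceptual sketch of that chunking idea only, not the TorchUtils implementation; the helper name and the torch-rb call style are assumptions.

# Conceptual sketch of chunked feed-forward (hypothetical helper, not library code).
# attention_output is assumed to be [batch, seq_len, hidden] and chunk_size > 0.
def chunked_feed_forward(layer, attention_output, chunk_size, seq_len_dim)
  chunks = attention_output.split(chunk_size, dim: seq_len_dim)
  outputs = chunks.map { |chunk| layer.feed_forward_chunk(chunk) }
  Torch.cat(outputs, dim: seq_len_dim)
end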
#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object
# File 'lib/transformers/models/bert/modeling_bert.rb', line 307

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
  self_attention_outputs = @attention.(
    hidden_states,
    attention_mask: attention_mask,
    head_mask: head_mask,
    output_attentions: output_attentions,
    past_key_value: self_attn_past_key_value
  )
  attention_output = self_attention_outputs[0]

  # if decoder, the last output is tuple of self-attn cache
  if @is_decoder
    outputs = self_attention_outputs[1...-1]
    present_key_value = self_attention_outputs[-1]
  else
    outputs = self_attention_outputs[1..] # add self attentions if we output attention weights
  end

  _cross_attn_present_key_value = nil
  if @is_decoder && !encoder_hidden_states.nil?
    raise Todo
  end

  layer_output = TorchUtils.apply_chunking_to_forward(
    method(:feed_forward_chunk),
    @chunk_size_feed_forward,
    @seq_len_dim,
    attention_output
  )
  outputs = [layer_output] + outputs

  # if decoder, return the attn key/values as the last output
  if @is_decoder
    outputs = outputs + [present_key_value]
  end

  outputs
end
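Continuing the construction sketch above, a single forward pass might look like the following; the tensor setup and shapes are illustrative assumptions, not part of this file. The return value is an array whose first element is the layer output, followed by the self-attention weights when output_attentions is true, with the key/value cache appended last when the layer is a decoder.

# Hypothetical forward pass through the layer built in the constructor sketch.
hidden_states = Torch.randn(1, 16, config.hidden_size)   # [batch, seq_len, hidden]

outputs = layer.(hidden_states, output_attentions: true)

layer_output = outputs[0]    # transformed hidden states, same shape as the input
attention_probs = outputs[1] # self-attention weights (present because output_attentions: true)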