Class: Transformers::Bert::BertLayer

Inherits: Torch::NN::Module < Object

Defined in: lib/transformers/models/bert/modeling_bert.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ BertLayer

Returns a new instance of BertLayer. The constructor builds the self-attention block, an optional cross-attention block (only for decoder configs with add_cross_attention), and the intermediate/output feed-forward pair.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 290

def initialize(config)
  super()
  @chunk_size_feed_forward = config.chunk_size_feed_forward
  @seq_len_dim = 1
  @attention = BertAttention.new(config)
  @is_decoder = config.is_decoder
  @add_cross_attention = config.add_cross_attention
  if @add_cross_attention
    if !@is_decoder
      raise ArgumentError, "#{self} should be used as a decoder model if cross attention is added"
    end
    @crossattention = BertAttention.new(config, position_embedding_type: "absolute")
  end
  @intermediate = BertIntermediate.new(config)
  @output = BertOutput.new(config)
end
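
A minimal construction sketch follows. The config attribute names mirror the fields the constructor and its submodules read; whether BertConfig.new accepts these keyword overrides in this port, and the exact require path, are assumptions:

require "transformers-rb"  # require path is an assumption

# Hypothetical bert-base-style values; BertConfig's keyword interface is
# assumed to match the Python library's defaults.
config = Transformers::Bert::BertConfig.new(
  hidden_size: 768,
  num_attention_heads: 12,
  intermediate_size: 3072,
  chunk_size_feed_forward: 0,  # 0 disables feed-forward chunking
  is_decoder: false,
  add_cross_attention: false
)

layer = Transformers::Bert::BertLayer.new(config)

# Random [batch, seq_len, hidden_size] input
hidden_states = Torch.randn(2, 16, config.hidden_size)
outputs = layer.(hidden_states)
outputs[0].shape  # => [2, 16, 768]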

Instance Method Details

#feed_forward_chunk(attention_output) ⇒ Object

Applies the intermediate sublayer and then the output sublayer to one chunk of the attention output. Called from #forward through TorchUtils.apply_chunking_to_forward.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 353

def feed_forward_chunk(attention_output)
  intermediate_output = @intermediate.(attention_output)
  layer_output = @output.(intermediate_output, attention_output)
  layer_output
end
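
For context, TorchUtils.apply_chunking_to_forward is roughly equivalent to the sketch below: with a positive chunk size it splits the input along the sequence dimension, applies the method to each slice, and concatenates the results, lowering peak memory at the cost of extra calls. This is an illustrative re-implementation, not the library's code, and assumes torch-rb's Tensor#split and Torch.cat:

# Illustrative sketch only
def chunked_forward(forward_fn, chunk_size, seq_len_dim, input_tensor)
  # chunk_size 0 means "no chunking", matching chunk_size_feed_forward's default
  return forward_fn.(input_tensor) if chunk_size == 0

  chunks = input_tensor.split(chunk_size, dim: seq_len_dim)
  Torch.cat(chunks.map { |chunk| forward_fn.(chunk) }, dim: seq_len_dim)
end

# e.g. chunked_forward(method(:feed_forward_chunk), 64, 1, attention_output)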

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_value: nil, output_attentions: false) ⇒ Object

Runs the layer: self-attention (reusing cached key/values from past_key_value when given), cross-attention for decoder configs with encoder states (unimplemented in this port; raises Todo), then the chunked feed-forward block. Returns an array: the layer output first, the attention weights next when output_attentions is true, and, for decoders, the present key/value cache last.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 307

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_value: nil,
  output_attentions: false
)
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  self_attn_past_key_value = !past_key_value.nil? ? past_key_value[...2] : nil
  self_attention_outputs = @attention.(
    hidden_states,
    attention_mask: attention_mask,
    head_mask: head_mask,
    output_attentions: output_attentions,
    past_key_value: self_attn_past_key_value
  )
  attention_output = self_attention_outputs[0]

  # if decoder, the last output is tuple of self-attn cache
  if @is_decoder
    outputs = self_attention_outputs[1...-1]
    present_key_value = self_attention_outputs[-1]
  else
    outputs = self_attention_outputs[1..]  # add self attentions if we output attention weights
  end

  _cross_attn_present_key_value = nil
  if @is_decoder && !encoder_hidden_states.nil?
    raise Todo
  end

  layer_output = TorchUtils.apply_chunking_to_forward(
    method(:feed_forward_chunk), @chunk_size_feed_forward, @seq_len_dim, attention_output
  )
  outputs = [layer_output] + outputs

  # if decoder, return the attn key/values as the last output
  if @is_decoder
    outputs = outputs + [present_key_value]
  end

  outputs
end
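
A sketch of unpacking the return value, reusing layer and hidden_states from the constructor example above; whether the output_attentions path is fully implemented in this port is an assumption:

outputs = layer.(hidden_states, output_attentions: true)
layer_output = outputs[0]     # [batch, seq_len, hidden_size]
attention_probs = outputs[1]  # self-attention weights, [batch, heads, seq, seq]

# For a decoder config, the last element would instead be the updated
# key/value cache, to be passed back in as past_key_value on the next step.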