Class: Transformers::Bert::BertEncoder

Inherits: Torch::NN::Module

Defined in: lib/transformers/models/bert/modeling_bert.rb

Instance Method Summary

Constructor Details

#initialize(config) ⇒ BertEncoder

Returns a new instance of BertEncoder.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 361

def initialize(config)
  super()
  @config = config
  # Build one BertLayer per configured hidden layer.
  @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { BertLayer.new(config) })
  @gradient_checkpointing = false
end
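
A minimal construction sketch. This is hedged: it assumes Transformers::Bert::BertConfig (from configuration_bert.rb in the same gem) accepts keyword overrides such as num_hidden_layers; the values below are illustrative only.

require "transformers"

# Hypothetical small config for illustration; field names assumed from the BERT port.
config = Transformers::Bert::BertConfig.new(
  num_hidden_layers: 4,
  hidden_size: 64,
  num_attention_heads: 4,
  intermediate_size: 128
)
encoder = Transformers::Bert::BertEncoder.new(config)
# The encoder now wraps a ModuleList of 4 BertLayer instances.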

Instance Method Details

#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_values: nil, use_cache: nil, output_attentions: false, output_hidden_states: false, return_dict: true) ⇒ Object



# File 'lib/transformers/models/bert/modeling_bert.rb', line 368

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: false,
  output_hidden_states: false,
  return_dict: true
)
  all_hidden_states = output_hidden_states ? [] : nil
  all_self_attentions = output_attentions ? [] : nil
  all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil

  # Gradient checkpointing while training is not yet implemented in this port.
  if @gradient_checkpointing && @training
    raise Todo
  end

  next_decoder_cache = use_cache ? [] : nil
  @layer.each_with_index do |layer_module, i|
    if output_hidden_states
      all_hidden_states = all_hidden_states + [hidden_states]
    end

    layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
    past_key_value = !past_key_values.nil? ? past_key_values[i] : nil

    if @gradient_checkpointing && @training
      raise Todo
    else
      layer_outputs = layer_module.(
        hidden_states,
        attention_mask: attention_mask,
        head_mask: layer_head_mask,
        encoder_hidden_states: encoder_hidden_states,
        encoder_attention_mask: encoder_attention_mask,
        past_key_value: past_key_value,
        output_attentions: output_attentions
      )
    end

    hidden_states = layer_outputs[0]
    if use_cache
      next_decoder_cache += [layer_outputs[-1]]
    end
    if output_attentions
      all_self_attentions = all_self_attentions + [layer_outputs[1]]
      if @config.add_cross_attention
        all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
      end
    end
  end

  if output_hidden_states
    all_hidden_states = all_hidden_states + [hidden_states]
  end

  if !return_dict
    raise Todo
  end
  BaseModelOutputWithPastAndCrossAttentions.new(
    last_hidden_state: hidden_states,
    past_key_values: next_decoder_cache,
    hidden_states: all_hidden_states,
    attentions: all_self_attentions,
    cross_attentions: all_cross_attentions
  )
end
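
A hedged usage sketch for #forward. It assumes the encoder and config built in the constructor sketch above, that hidden_states is already embedded (e.g. by BertEmbeddings) with shape [batch, seq_len, hidden_size], and that the returned BaseModelOutputWithPastAndCrossAttentions exposes readers matching the keyword arguments used in the constructor call; none of this is taken from the page itself.

# Illustrative shapes: batch of 1, sequence length of 8.
hidden_states = Torch.randn(1, 8, config.hidden_size)

outputs = encoder.(
  hidden_states,
  output_hidden_states: true,
  output_attentions: true
)

outputs.last_hidden_state.shape # => [1, 8, config.hidden_size]
outputs.hidden_states.length    # => config.num_hidden_layers + 1 (input state plus one per layer)
outputs.attentions.length       # => config.num_hidden_layers

In normal use inside BertModel, the attention_mask argument is an already-expanded additive mask rather than a raw 0/1 mask; passing nil, as here, simply skips masking.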