Class: Transformers::Bert::BertPredictionHeadTransform

Inherits:
Torch::NN::Module

Inheritance chain: Object → Torch::NN::Module → Transformers::Bert::BertPredictionHeadTransform
Defined in:
lib/transformers/models/bert/modeling_bert.rb

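Instance Method Summary

#initialize(config) ⇒ BertPredictionHeadTransform (constructor) — Returns a new instance of BertPredictionHeadTransform.
#forward(hidden_states) ⇒ Object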

Constructor Details

#initialize(config) ⇒ BertPredictionHeadTransform

Returns a new instance of BertPredictionHeadTransform.



Source lines 458–467:
# File 'lib/transformers/models/bert/modeling_bert.rb', line 458

# Builds the three sub-layers that #forward applies in sequence:
# a square Linear projection (hidden_size -> hidden_size), an activation,
# and a LayerNorm over hidden_size.
#
# @param config [Object] model configuration; must respond to
#   +hidden_size+, +hidden_act+ and +layer_norm_eps+.
#   +hidden_act+ may be a String (resolved through the ACT2FN registry) or
#   any other object, which is stored unchanged (presumably already a
#   callable activation — confirm against callers).
def initialize(config)
  super()
  width = config.hidden_size
  @dense = Torch::NN::Linear.new(width, width)

  # A String names a registered activation; anything else is assumed to be
  # the activation itself and is kept as-is.
  activation = config.hidden_act
  @transform_act_fn =
    case activation
    when String then ACT2FN[activation]
    else activation
    end

  # NOTE(review): the CamelCase ivar name presumably mirrors the pretrained
  # checkpoint's parameter name ("LayerNorm"); confirm before renaming.
  @LayerNorm = Torch::NN::LayerNorm.new(width, eps: config.layer_norm_eps)
end

Instance Method Details

#forward(hidden_states) ⇒ Object



Source lines 469–474:
# File 'lib/transformers/models/bert/modeling_bert.rb', line 469

# Applies the dense projection, then the activation, then LayerNorm.
#
# @param hidden_states [Object] input tensor — presumably shaped
#   (..., hidden_size) to match the Linear layer; TODO confirm
# @return [Object] the output of the final LayerNorm call
def forward(hidden_states)
  stages = [@dense, @transform_act_fn, @LayerNorm]
  stages.reduce(hidden_states) { |current, stage| stage.call(current) }
end