Class: Transformers::Bert::BertEncoder
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - Transformers::Bert::BertEncoder
- Defined in: lib/transformers/models/bert/modeling_bert.rb
Instance Method Summary
- #forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_values: nil, use_cache: nil, output_attentions: false, output_hidden_states: false, return_dict: true) ⇒ Object
- #initialize(config) ⇒ BertEncoder (constructor)
  A new instance of BertEncoder.
Constructor Details
#initialize(config) ⇒ BertEncoder
Returns a new instance of BertEncoder.
# File 'lib/transformers/models/bert/modeling_bert.rb', line 361

def initialize(config)
  super()
  @config = config
  @layer = Torch::NN::ModuleList.new(config.num_hidden_layers.times.map { BertLayer.new(config) })
  @gradient_checkpointing = false
end
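The constructor stacks config.num_hidden_layers BertLayer instances in a Torch::NN::ModuleList. A minimal construction sketch follows; the BertConfig class name and its defaults are assumptions for illustration, and any object exposing the attributes read by the encoder (num_hidden_layers, add_cross_attention) would work:

# A minimal sketch; the config class name and its defaults are assumptions.
config  = Transformers::Bert::BertConfig.new
encoder = Transformers::Bert::BertEncoder.new(config)
# encoder now holds config.num_hidden_layers BertLayer modules in @layer.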
Instance Method Details
#forward(hidden_states, attention_mask: nil, head_mask: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, past_key_values: nil, use_cache: nil, output_attentions: false, output_hidden_states: false, return_dict: true) ⇒ Object
# File 'lib/transformers/models/bert/modeling_bert.rb', line 368

def forward(
  hidden_states,
  attention_mask: nil,
  head_mask: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: false,
  output_hidden_states: false,
  return_dict: true
)
  all_hidden_states = output_hidden_states ? [] : nil
  all_self_attentions = output_attentions ? [] : nil
  all_cross_attentions = output_attentions && @config.add_cross_attention ? [] : nil

  if @gradient_checkpointing && @training
    raise Todo
  end

  next_decoder_cache = use_cache ? [] : nil
  @layer.each_with_index do |layer_module, i|
    if output_hidden_states
      all_hidden_states = all_hidden_states + [hidden_states]
    end

    layer_head_mask = !head_mask.nil? ? head_mask[i] : nil
    past_key_value = !past_key_values.nil? ? past_key_values[i] : nil

    if @gradient_checkpointing && @training
      raise Todo
    else
      layer_outputs = layer_module.(
        hidden_states,
        attention_mask: attention_mask,
        head_mask: layer_head_mask,
        encoder_hidden_states: encoder_hidden_states,
        encoder_attention_mask: encoder_attention_mask,
        past_key_value: past_key_value,
        output_attentions: output_attentions
      )
    end

    hidden_states = layer_outputs[0]
    if use_cache
      next_decoder_cache += [layer_outputs[-1]]
    end
    if output_attentions
      all_self_attentions = all_self_attentions + [layer_outputs[1]]
      if @config.add_cross_attention
        all_cross_attentions = all_cross_attentions + [layer_outputs[2]]
      end
    end
  end

  if output_hidden_states
    all_hidden_states = all_hidden_states + [hidden_states]
  end

  if !return_dict
    raise Todo
  end

  BaseModelOutputWithPastAndCrossAttentions.new(
    last_hidden_state: hidden_states,
    past_key_values: next_decoder_cache,
    hidden_states: all_hidden_states,
    attentions: all_self_attentions,
    cross_attentions: all_cross_attentions
  )
end
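Each iteration feeds the previous layer's output back in as hidden_states, optionally collecting per-layer hidden states, self-attention weights, and (when add_cross_attention is set) cross-attention weights, then wraps everything in a BaseModelOutputWithPastAndCrossAttentions. A minimal usage sketch, assuming an encoder built as above and an upstream embedding output of shape [batch_size, seq_len, hidden_size]; the tensor shape and the reader methods on the output object are assumptions for illustration:

# Illustrative shape only; normally this comes from BertEmbeddings.
hidden_states = Torch.randn(1, 8, 768)
outputs = encoder.(hidden_states, output_hidden_states: true)
outputs.last_hidden_state   # final layer output, same shape as the input
outputs.hidden_states       # per-layer outputs, collected when requested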