Class: Transformers::Distilbert::DistilBertForMaskedLM
- Inherits: DistilBertPreTrainedModel
  - Full ancestry: Object → Torch::NN::Module → PreTrainedModel → DistilBertPreTrainedModel → Transformers::Distilbert::DistilBertForMaskedLM
- Defined in: lib/transformers/models/distilbert/modeling_distilbert.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ DistilBertForMaskedLM (constructor)
  A new instance of DistilBertForMaskedLM.
Methods inherited from DistilBertPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ DistilBertForMaskedLM
Returns a new instance of DistilBertForMaskedLM.
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 424

def initialize(config)
  super(config)
  @activation = get_activation(config.activation)

  @distilbert = DistilBertModel.new(config)
  @vocab_transform = Torch::NN::Linear.new(config.dim, config.dim)
  @vocab_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
  @vocab_projector = Torch::NN::Linear.new(config.dim, config.vocab_size)

  # Initialize weights and apply final processing
  post_init

  @mlm_loss_fct = Torch::NN::CrossEntropyLoss.new
end
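Constructing the class directly from a config produces a randomly initialized model (post_init applies the weight initialization). A minimal sketch; the Transformers::Distilbert::DistilBertConfig class name and its no-argument defaults are assumptions not documented on this page:

require "transformers-rb"

# Build a default DistilBERT config and an untrained masked-LM model from it
# (config class name and defaults assumed)
config = Transformers::Distilbert::DistilBertConfig.new
model = Transformers::Distilbert::DistilBertForMaskedLM.new(config)

# For pretrained weights, use the inherited class method instead:
# model = Transformers::Distilbert::DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")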
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 440

def forward(
  input_ids: nil,
  attention_mask: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  dlbrt_output = @distilbert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  hidden_states = dlbrt_output[0]                          # (bs, seq_length, dim)
  prediction_logits = @vocab_transform.(hidden_states)     # (bs, seq_length, dim)
  prediction_logits = @activation.(prediction_logits)      # (bs, seq_length, dim)
  prediction_logits = @vocab_layer_norm.(prediction_logits) # (bs, seq_length, dim)
  prediction_logits = @vocab_projector.(prediction_logits) # (bs, seq_length, vocab_size)

  mlm_loss = nil
  if !labels.nil?
    mlm_loss = @mlm_loss_fct.(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
  end

  if !return_dict
    raise Todo
  end

  MaskedLMOutput.new(
    loss: mlm_loss,
    logits: prediction_logits,
    hidden_states: dlbrt_output.hidden_states,
    attentions: dlbrt_output.attentions
  )
end
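In practice, #forward is driven by a tokenizer. A minimal usage sketch; the "distilbert-base-uncased" checkpoint, Transformers::AutoTokenizer, the return_tensors: "pt" option, and the mask_token_id and decode helpers are assumptions about the surrounding gem, not part of this class's documented interface:

require "transformers-rb"

tokenizer = Transformers::AutoTokenizer.from_pretrained("distilbert-base-uncased") # assumed helper
model = Transformers::Distilbert::DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")

# Encode a sentence containing the mask token (assumed tokenizer API)
inputs = tokenizer.("The capital of France is [MASK].", return_tensors: "pt")

# No labels are passed, so the returned MaskedLMOutput has loss: nil;
# passing labels: would populate loss via the internal CrossEntropyLoss
output = model.forward(input_ids: inputs[:input_ids], attention_mask: inputs[:attention_mask])

# Take the highest-scoring vocabulary id at the masked position
mask_index = inputs[:input_ids][0].to_a.index(tokenizer.mask_token_id) # assumed accessor
predicted_id = Torch.argmax(output.logits[0, mask_index]).item
puts tokenizer.decode([predicted_id]) # assumed decode helper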