Class: Transformers::Distilbert::DistilBertForMaskedLM

Inherits:
DistilBertPreTrainedModel
Defined in:
lib/transformers/models/distilbert/modeling_distilbert.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from DistilBertPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ DistilBertForMaskedLM

Returns a new instance of DistilBertForMaskedLM.



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 424

def initialize(config)
  super(config)

  @activation = get_activation(config.activation)

  @distilbert = DistilBertModel.new(config)
  @vocab_transform = Torch::NN::Linear.new(config.dim, config.dim)
  @vocab_layer_norm = Torch::NN::LayerNorm.new(config.dim, eps: 1e-12)
  @vocab_projector = Torch::NN::Linear.new(config.dim, config.vocab_size)

  # Initialize weights and apply final processing
  post_init

  @mlm_loss_fct = Torch::NN::CrossEntropyLoss.new
end
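
A hedged construction sketch follows: it assumes the gem is required as "transformers-rb" and that the inherited PreTrainedModel.from_pretrained class method (listed above) accepts a Hugging Face model id; the model id shown is illustrative, not part of this class's documentation.

require "transformers-rb"

# Load pretrained weights through the inherited from_pretrained class method.
# The model id is an illustrative assumption.
model = Transformers::Distilbert::DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")

Alternatively, the constructor above can be called directly with an existing DistilBERT config object, in which case the weights are freshly initialized by post_init rather than loaded from a checkpoint.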

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 440

def forward(
  input_ids: nil,
  attention_mask: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  dlbrt_output = @distilbert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
  prediction_logits = @vocab_transform.(hidden_states)  # (bs, seq_length, dim)
  prediction_logits = @activation.(prediction_logits)  # (bs, seq_length, dim)
  prediction_logits = @vocab_layer_norm.(prediction_logits)  # (bs, seq_length, dim)
  prediction_logits = @vocab_projector.(prediction_logits)  # (bs, seq_length, vocab_size)

  mlm_loss = nil
  if !labels.nil?
    mlm_loss = @mlm_loss_fct.(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))
  end

  if !return_dict
    raise Todo
  end

  MaskedLMOutput.new(
    loss: mlm_loss,
    logits: prediction_logits,
    hidden_states: dlbrt_output.hidden_states,
    attentions: dlbrt_output.attentions
  )
end
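
A hedged call sketch: the tensors below are illustrative placeholders (real ids come from a DistilBERT tokenizer, which is outside this page's scope), and only keywords from the signature above are used. When labels are supplied, positions set to -100 are skipped by CrossEntropyLoss's default ignore_index, and the masked-LM loss is returned on the output.

# Illustrative token ids only; real ids come from a DistilBERT tokenizer.
# Position 3 stands in for a [MASK] token.
input_ids      = Torch.tensor([[101, 2054, 2003, 103, 1029, 102]])
attention_mask = Torch.ones_like(input_ids)

output = model.forward(input_ids: input_ids, attention_mask: attention_mask, return_dict: true)

# logits has shape (bs, seq_length, vocab_size); pick the most likely id at the masked position
mask_position = 3
predicted_id = output.logits[0, mask_position].argmax.item

# Supplying labels (-100 marks positions to ignore) also returns the masked-LM cross-entropy loss
labels = Torch.tensor([[-100, -100, -100, 2307, -100, -100]])
loss = model.forward(input_ids: input_ids, attention_mask: attention_mask, labels: labels, return_dict: true).loss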