Class: Transformers::Bert::BertForMaskedLM

Inherits:
BertPreTrainedModel
Defined in:
lib/transformers/models/bert/modeling_bert.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from BertPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ BertForMaskedLM

Returns a new instance of BertForMaskedLM.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 702

def initialize(config)
  super(config)

  if config.is_decoder
    Transformers.logger.warn(
      "If you want to use `BertForMaskedLM` make sure `config.is_decoder: false` for " +
      "bi-directional self-attention."
    )
  end

  @bert = BertModel.new(config, add_pooling_layer: false)
  @cls = BertOnlyMLMHead.new(config)
end
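
A minimal usage sketch (not generated from the source): it assumes Transformers::Bert::BertConfig mirrors the Python configuration class and that the inherited from_pretrained class method accepts a Hugging Face model id; both are assumptions about this release, not verified behavior.

# Sketch only: BertConfig and its keyword arguments are assumed to mirror the
# Python library's defaults; adjust them to match a real checkpoint.
config = Transformers::Bert::BertConfig.new(vocab_size: 30_522, is_decoder: false)
model = Transformers::Bert::BertForMaskedLM.new(config)

# Or load published weights via the inherited class method (model id
# resolution is assumed to follow the Python API):
model = Transformers::Bert::BertForMaskedLM.from_pretrained("bert-base-uncased")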

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/bert/modeling_bert.rb', line 716

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : config.use_return_dict

  outputs = @bert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    encoder_hidden_states: encoder_hidden_states,
    encoder_attention_mask: encoder_attention_mask,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]
  prediction_scores = @cls.(sequence_output)

  masked_lm_loss = nil
  if !labels.nil?
    raise Todo
  end

  if !return_dict
    raise Todo
  end

  MaskedLMOutput.new(
    loss: masked_lm_loss,
    logits: prediction_scores,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
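
In this version the labels and return_dict: false branches raise Todo, so no loss is computed; call forward without labels and rely on the default return_dict. A hedged calling sketch follows: the token ids are illustrative stand-ins for real tokenizer output, and the torch-rb calls (Torch.tensor, Torch.ones_like, Tensor#argmax, Tensor#item) are assumptions about the installed Torch bindings rather than APIs documented on this page.

# Sketch only: load weights via the inherited class method, then call
# #forward with plain Torch tensors.
model = Transformers::Bert::BertForMaskedLM.from_pretrained("bert-base-uncased")

# Illustrative ids for "[CLS] the capital of france is [MASK] . [SEP]";
# in practice these come from a matching BERT tokenizer.
input_ids = Torch.tensor([[101, 1996, 3007, 1997, 2605, 2003, 103, 1012, 102]])
attention_mask = Torch.ones_like(input_ids)

output = model.forward(input_ids: input_ids, attention_mask: attention_mask)

# output is a MaskedLMOutput; logits has shape [batch, seq_len, vocab_size].
# The argmax at the [MASK] position (index 6 here) is the predicted token id.
predicted_id = output.logits[0, 6].argmax.item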