Class: Transformers::Bert::BertForMaskedLM
- Inherits: BertPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - BertPreTrainedModel
  - Transformers::Bert::BertForMaskedLM
- Defined in: lib/transformers/models/bert/modeling_bert.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ BertForMaskedLM (constructor): A new instance of BertForMaskedLM.
Methods inherited from BertPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ BertForMaskedLM
Returns a new instance of BertForMaskedLM.
# File 'lib/transformers/models/bert/modeling_bert.rb', line 702

def initialize(config)
  super(config)

  if config.is_decoder
    Transformers.logger.warn(
      "If you want to use `BertForMaskedLM` make sure `config.is_decoder: false` for " +
      "bi-directional self-attention."
    )
  end

  @bert = BertModel.new(config, add_pooling_layer: false)
  @cls = BertOnlyMLMHead.new(config)
end
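For orientation, here is a minimal construction sketch. It is a hedged example, not part of the library documentation: the Transformers::Bert::BertConfig constant and its defaults are assumptions, and a model built this way has randomly initialized weights, so from_pretrained (inherited from PreTrainedModel) is the usual entry point.

# Hedged sketch; BertConfig and its defaults are assumptions for illustration.
config = Transformers::Bert::BertConfig.new
model = Transformers::Bert::BertForMaskedLM.new(config)
# Weights are randomly initialized here; for real predictions load a checkpoint:
# model = Transformers::Bert::BertForMaskedLM.from_pretrained("bert-base-uncased")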
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/bert/modeling_bert.rb', line 716

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : config.use_return_dict

  outputs = @bert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    encoder_hidden_states: encoder_hidden_states,
    encoder_attention_mask: encoder_attention_mask,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]
  prediction_scores = @cls.(sequence_output)

  masked_lm_loss = nil
  if !labels.nil?
    raise Todo
  end

  if !return_dict
    raise Todo
  end

  MaskedLMOutput.new(
    loss: masked_lm_loss,
    logits: prediction_scores,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
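A hedged end-to-end sketch of calling #forward for fill-mask style inference follows. The checkpoint name, Transformers::AutoTokenizer, the return_tensors: "pt" option, and the torch-rb argmax call are assumptions modeled on the Python API shape, not guarantees of this gem's interface.

# Hedged example; the tokenizer API and tensor helpers below are assumptions.
tokenizer = Transformers::AutoTokenizer.from_pretrained("bert-base-uncased")
model = Transformers::Bert::BertForMaskedLM.from_pretrained("bert-base-uncased")

encoding = tokenizer.("The capital of France is [MASK].", return_tensors: "pt")
output = model.forward(
  input_ids: encoding[:input_ids],
  attention_mask: encoding[:attention_mask]
)

# output is a MaskedLMOutput; logits has shape [batch, seq_len, vocab_size].
# Assuming torch-rb exposes argmax like PyTorch, take the top token id per position.
predicted_ids = output.logits.argmax(-1)

Note that both the labels branch and the return_dict: false branch in #forward currently raise Todo, so call it without labels and leave return_dict at its default if you want a usable MaskedLMOutput.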