Class: Transformers::XlmRoberta::XLMRobertaForMaskedLM
- Inherits: XLMRobertaPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - XLMRobertaPreTrainedModel
  - Transformers::XlmRoberta::XLMRobertaForMaskedLM
- Defined in: lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #get_output_embeddings ⇒ Object
- #initialize(config) ⇒ XLMRobertaForMaskedLM (constructor): A new instance of XLMRobertaForMaskedLM.
- #set_output_embeddings(new_embeddings) ⇒ Object
Methods inherited from XLMRobertaPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ XLMRobertaForMaskedLM
Returns a new instance of XLMRobertaForMaskedLM.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 860

def initialize(config)
  super(config)

  if config.is_decoder
    Transformers.logger.warn("If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder: false` for bi-directional self-attention.")
  end

  @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
  @lm_head = XLMRobertaLMHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
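A minimal usage sketch: the inherited from_pretrained class method is the usual entry point, and direct construction from a config object mirrors the constructor above. The "xlm-roberta-base" checkpoint id, network access to the Hugging Face Hub, and the XLMRobertaConfig keyword arguments are illustrative assumptions, not values documented on this page.

  require "transformers"

  # Usual path: load weights and config from a pretrained checkpoint
  # (checkpoint id assumed for illustration).
  model = Transformers::XlmRoberta::XLMRobertaForMaskedLM.from_pretrained("xlm-roberta-base")

  # Direct construction from a config object; the config class and its
  # keyword arguments are assumptions here.
  config = Transformers::XlmRoberta::XLMRobertaConfig.new(vocab_size: 250_002, is_decoder: false)
  scratch_model = Transformers::XlmRoberta::XLMRobertaForMaskedLM.new(config)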
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 882

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @roberta.(
    input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    encoder_hidden_states: encoder_hidden_states,
    encoder_attention_mask: encoder_attention_mask,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  sequence_output = outputs[0]
  prediction_scores = @lm_head.(sequence_output)

  masked_lm_loss = nil
  if !labels.nil?
    # move labels to correct device to enable model parallelism
    labels = labels.to(prediction_scores.device)
    loss_fct = Torch::NN::CrossEntropyLoss.new
    masked_lm_loss = loss_fct.(prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
  end

  if !return_dict
    output = [prediction_scores] + outputs[2..]
    return !masked_lm_loss.nil? ? [masked_lm_loss] + output : output
  end

  MaskedLMOutput.new(loss: masked_lm_loss, logits: prediction_scores, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
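A hedged sketch of calling #forward on a loaded model for mask filling. The token ids are placeholders rather than real XLM-R tokenizer output (250_001 merely stands in for the mask token id), and return_dict: false selects the plain array return path shown in the source above, with the prediction scores first.

  # Placeholder input: one sequence with an assumed masked position at index 3.
  input_ids = Torch.tensor([[0, 10, 20, 250_001, 30, 2]])
  attention_mask = Torch.ones_like(input_ids)

  model.eval
  prediction_scores = Torch.no_grad do
    model.forward(input_ids: input_ids, attention_mask: attention_mask, return_dict: false)[0]
  end

  # prediction_scores has shape [batch, seq_len, vocab_size]; pick the
  # highest-scoring token id at the masked position.
  predicted_id = prediction_scores[0, 3].argmax.item

When labels: is also passed, the returned array gains the masked LM loss as its first element, as in the source above.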
#get_output_embeddings ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 874

def get_output_embeddings
  @lm_head.decoder
end
#set_output_embeddings(new_embeddings) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 878

def set_output_embeddings(new_embeddings)
  @decoder = new_embeddings
end
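A brief sketch of the embedding accessors, assuming the LM head's decoder is a Torch::NN::Linear projection (its exact type and dimensions are not shown on this page).

  # Read the current output projection from the LM head.
  decoder = model.get_output_embeddings
  puts decoder.weight.shape.inspect # assumed [vocab_size, hidden_size]

  # Install a replacement projection, e.g. after growing the vocabulary.
  # The dimensions here are illustrative assumptions.
  new_decoder = Torch::NN::Linear.new(768, 250_004)
  model.set_output_embeddings(new_decoder)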