Class: Transformers::XlmRoberta::XLMRobertaForCausalLM
- Inherits: XLMRobertaPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - XLMRobertaPreTrainedModel
  - Transformers::XlmRoberta::XLMRobertaForCausalLM
- Defined in: lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #_reorder_cache(past_key_values, beam_idx) ⇒ Object
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, labels: nil, past_key_values: nil, use_cache: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #get_output_embeddings ⇒ Object
- #initialize(config) ⇒ XLMRobertaForCausalLM (constructor): A new instance of XLMRobertaForCausalLM.
- #prepare_inputs_for_generation(input_ids, past_key_values: nil, attention_mask: nil, **model_kwargs) ⇒ Object
- #set_output_embeddings(new_embeddings) ⇒ Object
Methods inherited from XLMRobertaPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ XLMRobertaForCausalLM
Returns a new instance of XLMRobertaForCausalLM.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 756

def initialize(config)
  super(config)

  if !config.is_decoder
    Transformers.logger.warn("If you want to use `XLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
  end

  @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
  @lm_head = XLMRobertaLMHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
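The constructor wires an XLMRobertaModel encoder (built without a pooling layer) to an XLMRobertaLMHead, then runs post_init. A minimal usage sketch follows, using the inherited from_pretrained class method listed above; the checkpoint name is an example and not taken from this file, and the gem is assumed to be already loaded.

# Example checkpoint name (assumption). Unless the loaded config has is_decoder set,
# the constructor logs the standalone-decoder warning shown in the source above.
model = Transformers::XlmRoberta::XLMRobertaForCausalLM.from_pretrained("FacebookAI/xlm-roberta-base")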
Instance Method Details
#_reorder_cache(past_key_values, beam_idx) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 848

def _reorder_cache(past_key_values, beam_idx)
  reordered_past = []
  past_key_values.each do |layer_past|
    # gather each cached key/value tensor along the batch dimension so the
    # cache rows follow the beams selected at this generation step
    reordered_past += [Array(layer_past.map { |past_state| past_state.index_select(0, beam_idx.to(past_state.device)) })]
  end
  reordered_past
end
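During beam search the cached states must be permuted so that each cache row stays attached to its beam; this is just an index_select along dimension 0 with beam_idx. A small illustration with a made-up cached tensor (shapes and values are not from this file):

past_state = Torch.arange(4).reshape(4, 1)   # pretend cache rows for 4 beam hypotheses
beam_idx = Torch.tensor([2, 0, 0, 3])        # beams chosen at this step
reordered = past_state.index_select(0, beam_idx)   # rows now come from beams 2, 0, 0, 3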
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, encoder_hidden_states: nil, encoder_attention_mask: nil, labels: nil, past_key_values: nil, use_cache: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 778

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  encoder_hidden_states: nil,
  encoder_attention_mask: nil,
  labels: nil,
  past_key_values: nil,
  use_cache: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  if !labels.nil?
    use_cache = false
  end

  outputs = @roberta.(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, encoder_hidden_states: encoder_hidden_states, encoder_attention_mask: encoder_attention_mask, past_key_values: past_key_values, use_cache: use_cache, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

  sequence_output = outputs[0]
  prediction_scores = @lm_head.(sequence_output)

  lm_loss = nil
  if !labels.nil?
    # move labels to correct device to enable model parallelism
    labels = labels.to(prediction_scores.device)
    # we are doing next-token prediction; shift prediction scores and input ids by one
    shifted_prediction_scores = prediction_scores[0.., ...-1, 0..].contiguous
    labels = labels[0.., 1..].contiguous
    loss_fct = Torch::NN::CrossEntropyLoss.new
    lm_loss = loss_fct.(shifted_prediction_scores.view(-1, @config.vocab_size), labels.view(-1))
  end

  if !return_dict
    output = [prediction_scores] + outputs[2..]
    return !lm_loss.nil? ? [lm_loss] + output : output
  end

  CausalLMOutputWithCrossAttentions.new(loss: lm_loss, logits: prediction_scores, past_key_values: outputs.past_key_values, hidden_states: outputs.hidden_states, attentions: outputs.attentions, cross_attentions: outputs.cross_attentions)
end
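When labels are passed, forward shifts the prediction scores and labels by one position and computes the next-token cross-entropy. A sketch of calling it directly, assuming model is an instance loaded elsewhere and using made-up token IDs:

input_ids = Torch.tensor([[0, 2262, 35378, 2]])   # fabricated XLM-R token IDs
out = model.forward(input_ids: input_ids, labels: input_ids)
out.loss    # scalar next-token cross-entropy over the shifted sequence
out.logits  # prediction scores of shape [batch, seq_len, vocab_size]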
#get_output_embeddings ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 770

def get_output_embeddings
  @lm_head.decoder
end
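The output embeddings are the decoder Linear layer inside the LM head, projecting hidden states to vocabulary logits. A brief sketch, assuming a loaded model instance (the concrete sizes are example values for the base configuration):

decoder = model.get_output_embeddings
decoder.weight.shape   # => [vocab_size, hidden_size], e.g. [250002, 768] for xlm-roberta-base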
#prepare_inputs_for_generation(input_ids, past_key_values: nil, attention_mask: nil, **model_kwargs) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 823

def prepare_inputs_for_generation(input_ids, past_key_values: nil, attention_mask: nil, **model_kwargs)
  input_shape = input_ids.shape
  # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
  if attention_mask.nil?
    attention_mask = input_ids.new_ones(input_shape)
  end

  # cut decoder_input_ids if past_key_values is used
  if !past_key_values.nil?
    past_length = past_key_values[0][0].shape[2]

    # Some generation methods already pass only the last input ID
    if input_ids.shape[1] > past_length
      remove_prefix_length = past_length
    else
      # Default to old behavior: keep only final ID
      remove_prefix_length = input_ids.shape[1] - 1
    end

    input_ids = input_ids[0.., remove_prefix_length..]
  end

  {"input_ids" => input_ids, "attention_mask" => attention_mask, "past_key_values" => past_key_values}
end
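With a cache present, only the input IDs not yet covered by the cache are kept, usually just the newest token. A sketch with fabricated shapes (1 batch, 12 heads, 4 cached positions, head size 64; none of these values come from this file), again assuming a loaded model instance:

input_ids = Torch.tensor([[0, 9, 10, 11, 2]])
attention_mask = Torch.ones(1, 5)
layer = [Torch.zeros(1, 12, 4, 64), Torch.zeros(1, 12, 4, 64)]   # fake cached key/value pair
inputs = model.prepare_inputs_for_generation(input_ids, past_key_values: [layer], attention_mask: attention_mask)
inputs["input_ids"]   # only the final column of input_ids, since 4 positions are already cached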
#set_output_embeddings(new_embeddings) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 774

def set_output_embeddings(new_embeddings)
  @lm_head.decoder = new_embeddings
end
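Replacing the output projection, for instance after changing the vocabulary, goes through this setter. A sketch using assumed dimensions (hidden size 768 and vocabulary size 250002 are example values, not read from the model):

new_head = Torch::NN::Linear.new(768, 250002)   # assumed hidden_size -> vocab_size projection
model.set_output_embeddings(new_head)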