Class: Transformers::Mpnet::MPNetForQuestionAnswering
- Inherits: MPNetPreTrainedModel
- Full ancestry: Object → Torch::NN::Module → PreTrainedModel → MPNetPreTrainedModel → Transformers::Mpnet::MPNetForQuestionAnswering
- Defined in: lib/transformers/models/mpnet/modeling_mpnet.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, start_positions: nil, end_positions: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ MPNetForQuestionAnswering (constructor)
  A new instance of MPNetForQuestionAnswering.
Methods inherited from MPNetPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from Transformers::ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ MPNetForQuestionAnswering
Returns a new instance of MPNetForQuestionAnswering.
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 727

def initialize(config)
  super(config)

  @num_labels = config.num_labels
  @mpnet = MPNetModel.new(config, add_pooling_layer: false)
  @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
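The head is deliberately small: the MPNet backbone is instantiated without its pooling layer, and a single Torch::NN::Linear maps each token's hidden state to config.num_labels logits (two for extractive QA: span start and span end). As a usage illustration, here is a minimal sketch of constructing the model; the "microsoft/mpnet-base" checkpoint name and the config-only path are assumptions for illustration, not part of this page.

require "transformers"

# Minimal sketch (assumed checkpoint name): load weights via the inherited
# PreTrainedModel.from_pretrained; a QA head absent from the checkpoint
# would be randomly initialized.
model = Transformers::Mpnet::MPNetForQuestionAnswering.from_pretrained("microsoft/mpnet-base")

# Alternatively, build from a config alone (all weights randomly initialized):
# config = Transformers::Mpnet::MPNetConfig.new
# model = Transformers::Mpnet::MPNetForQuestionAnswering.new(config)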
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, start_positions: nil, end_positions: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 738

def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  start_positions: nil,
  end_positions: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @mpnet.(input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

  sequence_output = outputs[0]

  logits = @qa_outputs.(sequence_output)
  start_logits, end_logits = logits.split(1, dim: -1)
  start_logits = start_logits.squeeze(-1).contiguous
  end_logits = end_logits.squeeze(-1).contiguous

  total_loss = nil
  if !start_positions.nil? && !end_positions.nil?
    # If we are on multi-GPU, splitting adds a dimension
    if start_positions.size.length > 1
      start_positions = start_positions.squeeze(-1)
    end
    if end_positions.size.length > 1
      end_positions = end_positions.squeeze(-1)
    end
    # sometimes the start/end positions are outside our model inputs, we ignore these terms
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
    start_loss = loss_fct.(start_logits, start_positions)
    end_loss = loss_fct.(end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2
  end

  if !return_dict
    output = [start_logits, end_logits] + outputs[2..]
    return !total_loss.nil? ? [total_loss] + output : output
  end

  QuestionAnsweringModelOutput.new(loss: total_loss, start_logits: start_logits, end_logits: end_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
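In short: the backbone produces per-token hidden states, @qa_outputs projects each one to two logits, and the split/squeeze pair turns those into separate start and end logit vectors over the sequence. When gold start_positions and end_positions are supplied, the loss is the mean of the two cross-entropy terms, with out-of-range positions clamped to ignored_index so they contribute nothing. The sketch below shows one way to run extractive QA end to end; the tokenizer class, checkpoint name, and return_tensors behavior are assumptions, since this page documents only the model head.

require "transformers"

# Hypothetical end-to-end example (checkpoint and tokenizer are assumptions).
tokenizer = Transformers::AutoTokenizer.from_pretrained("microsoft/mpnet-base")
model = Transformers::Mpnet::MPNetForQuestionAnswering.from_pretrained("microsoft/mpnet-base")

question = "Who wrote the report?"
context = "The report was written by the audit team in 2021."
encoding = tokenizer.(question, context, return_tensors: "pt")

output = model.forward(
  input_ids: encoding[:input_ids],
  attention_mask: encoding[:attention_mask]
)

# Greedy span decoding: take the argmax start and end token indices and
# decode the tokens between them.
start_index = output.start_logits.argmax(1).item
end_index = output.end_logits.argmax(1).item
answer_ids = encoding[:input_ids][0][start_index..end_index]
puts tokenizer.decode(answer_ids.to_a)

For training, passing start_positions: and end_positions: (token indices of the gold span, one per batch item) makes output.loss non-nil, holding the averaged start/end cross-entropy described above.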