Class: Transformers::Mpnet::MPNetForMultipleChoice
- Inherits: MPNetPreTrainedModel
  - Object
    - Torch::NN::Module
      - PreTrainedModel
        - MPNetPreTrainedModel
          - Transformers::Mpnet::MPNetForMultipleChoice
- Defined in: lib/transformers/models/mpnet/modeling_mpnet.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ MPNetForMultipleChoice (constructor): Returns a new instance of MPNetForMultipleChoice.
Methods inherited from MPNetPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from Transformers::ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ MPNetForMultipleChoice
Returns a new instance of MPNetForMultipleChoice.
Source (lines 607–616):
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 607
# Builds the multiple-choice head on top of an MPNet encoder.
#
# Three submodules are created: the MPNet encoder itself, a dropout layer whose
# probability is config.hidden_dropout_prob, and a linear layer mapping a
# config.hidden_size-wide vector to a single logit per choice.
#
# @param config model configuration; read here for hidden_dropout_prob and
#   hidden_size, and also handed to the superclass and to MPNetModel.new
def initialize(config)
  # Bare `super` forwards the same `config` argument this method received.
  super
  @mpnet = MPNetModel.new(config)
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  @classifier = Torch::NN::Linear.new(config.hidden_size, 1)
  # Initialize weights and apply final processing once every submodule exists.
  post_init
end
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
Source (lines 618–656):
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 618
# Scores every candidate choice with the MPNet encoder and a one-logit head.
#
# Each per-choice input is flattened so that the encoder sees
# (batch * num_choices) independent rows. The single logit per row is then
# reshaped back to (-1, num_choices), giving one row of choice scores per
# example.
#
# @param input_ids [Torch::Tensor, nil] token ids; dimension 1 is read as the
#   number of choices (presumably (batch, num_choices, seq_len) — confirm)
# @param attention_mask [Torch::Tensor, nil] flattened over all but its last dim
# @param position_ids [Torch::Tensor, nil] flattened over all but its last dim
# @param head_mask [Object, nil] passed to the encoder unchanged
# @param inputs_embeds [Torch::Tensor, nil] alternative to input_ids; flattened
#   over all but its last two dims; dimension 1 is the number of choices
# @param labels [Torch::Tensor, nil] when present, a cross-entropy loss over the
#   reshaped logits is computed
# @param output_attentions [Boolean, nil] forwarded to the encoder
# @param output_hidden_states [Boolean, nil] forwarded to the encoder
# @param return_dict [Boolean, nil] falls back to @config.use_return_dict when nil
# @return [MultipleChoiceModelOutput, Array] a model output object, or, when
#   return_dict is false, [loss (only if labels given), logits, *outputs[2..]]
# @raise [ArgumentError] if neither input_ids nor inputs_embeds is given
def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  # Only nil falls back to the config; an explicit false is respected.
  return_dict = @config.use_return_dict if return_dict.nil?

  if input_ids.nil? && inputs_embeds.nil?
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end
  num_choices = (input_ids || inputs_embeds).shape[1]

  # Collapse the leading (batch, num_choices) dims into one; `&.` leaves nil inputs nil.
  flat_input_ids = input_ids&.view(-1, input_ids.size(-1))
  flat_position_ids = position_ids&.view(-1, position_ids.size(-1))
  flat_attention_mask = attention_mask&.view(-1, attention_mask.size(-1))
  flat_inputs_embeds = inputs_embeds&.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

  outputs = @mpnet.(
    flat_input_ids,
    position_ids: flat_position_ids,
    attention_mask: flat_attention_mask,
    head_mask: head_mask,
    inputs_embeds: flat_inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  # Element 1 of the encoder output is used as the pooled representation.
  pooled_output = @dropout.(outputs[1])
  logits = @classifier.(pooled_output)
  reshaped_logits = logits.view(-1, num_choices)

  loss = nil
  unless labels.nil?
    loss_fct = Torch::NN::CrossEntropyLoss.new
    loss = loss_fct.(reshaped_logits, labels)
  end

  unless return_dict
    output = [reshaped_logits] + outputs[2..]
    return loss.nil? ? output : [loss] + output
  end

  MultipleChoiceModelOutput.new(
    loss: loss,
    logits: reshaped_logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end