Class: Transformers::Mpnet::MPNetForMultipleChoice

Inherits:
MPNetPreTrainedModel
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from MPNetPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from Transformers::ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ MPNetForMultipleChoice

Returns a new instance of MPNetForMultipleChoice.



607
608
609
610
611
612
613
614
615
616
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 607

# Builds a multiple-choice head on top of an MPNet encoder.
#
# The head applies dropout to the encoder's pooled output and then a linear
# layer with a single output unit, so every choice gets one score.
#
# @param config model configuration; this constructor reads
#   `hidden_dropout_prob` and `hidden_size` from it and hands it to the
#   parent constructor and to MPNetModel.
def initialize(config)
  super

  # Encoder shared by every choice.
  @mpnet = MPNetModel.new(config)

  # Scoring head: dropout, then a hidden_size -> 1 projection.
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  scores_per_choice = 1
  @classifier = Torch::NN::Linear.new(config.hidden_size, scores_per_choice)

  # Weight initialization and any final setup defined by the parent class.
  post_init
end

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 618

# Scores every candidate choice of every example with MPNet.
#
# Inputs carry a choice dimension at index 1. They are flattened so the
# encoder sees one row per (example, choice) pair. The single-logit classifier
# then produces one score per row, and those scores are reshaped back into
# one row of `num_choices` logits per example.
#
# @param input_ids token ids; `shape[1]` is taken as the number of choices
#   (presumably (batch, num_choices, seq_len) — TODO confirm)
# @param attention_mask flattened over its last dimension like input_ids
# @param position_ids flattened over its last dimension like input_ids
# @param head_mask passed to the encoder unchanged
# @param inputs_embeds alternative to input_ids; flattened over its last two
#   dimensions (presumably (batch, num_choices, seq_len, hidden) — TODO confirm)
# @param labels when given, a cross-entropy loss against the reshaped logits
#   is computed
# @param output_attentions forwarded to the encoder
# @param output_hidden_states forwarded to the encoder
# @param return_dict when nil, falls back to `@config.use_return_dict`
# @return [MultipleChoiceModelOutput, Array] a MultipleChoiceModelOutput when
#   return_dict is truthy; otherwise an Array of
#   `[loss (only if labels given), reshaped_logits, *outputs[2..]]`
# @raise [ArgumentError] if neither input_ids nor inputs_embeds is provided
def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  # Fail clearly here rather than with NoMethodError on nil when reading shape.
  if input_ids.nil? && inputs_embeds.nil?
    raise ArgumentError, "You have to specify either input_ids or inputs_embeds"
  end

  # `||` would discard an explicit `false`, so only default when nil.
  return_dict = @config.use_return_dict if return_dict.nil?
  num_choices = (input_ids || inputs_embeds).shape[1]

  # Fold the choice dimension into the batch dimension (nil stays nil).
  flat_input_ids = input_ids&.view(-1, input_ids.size(-1))
  flat_position_ids = position_ids&.view(-1, position_ids.size(-1))
  flat_attention_mask = attention_mask&.view(-1, attention_mask.size(-1))
  flat_inputs_embeds = inputs_embeds&.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

  outputs = @mpnet.(
    flat_input_ids,
    position_ids: flat_position_ids,
    attention_mask: flat_attention_mask,
    head_mask: head_mask,
    inputs_embeds: flat_inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  # Element 1 of the encoder output is used as the pooled representation.
  pooled_output = @dropout.(outputs[1])
  logits = @classifier.(pooled_output)
  # One score per flattened row -> one row of num_choices scores per example.
  reshaped_logits = logits.view(-1, num_choices)

  loss = labels.nil? ? nil : Torch::NN::CrossEntropyLoss.new.(reshaped_logits, labels)

  unless return_dict
    output = [reshaped_logits] + outputs[2..]
    return loss.nil? ? output : [loss] + output
  end

  MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end