Class: Transformers::Mpnet::MPNetForSequenceClassification

Inherits:
MPNetPreTrainedModel
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object

#initialize(config) ⇒ MPNetForSequenceClassification constructor

Methods inherited from MPNetPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from Transformers::ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ MPNetForSequenceClassification

Returns a new instance of MPNetForSequenceClassification.

# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 542

def initialize(config)
  super(config)

  @num_labels = config.num_labels
  @mpnet = MPNetModel.new(config, add_pooling_layer: false)
  @classifier = MPNetClassificationHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
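
A minimal loading sketch (an assumption, not part of this page): in practice the model usually comes from a pretrained checkpoint via the inherited PreTrainedModel.from_pretrained listed above. The checkpoint name below is illustrative, and a base checkpoint would leave the classification head newly initialized.

require "transformers-rb"

# Hypothetical checkpoint; any MPNet checkpoint is loaded the same way.
model = Transformers::Mpnet::MPNetForSequenceClassification.from_pretrained("microsoft/mpnet-base")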

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object

# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 553

def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
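  # Fall back to the config default when return_dict is not given explicitly.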
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @mpnet.(
    input_ids,
    attention_mask: attention_mask,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  sequence_output = outputs[0]
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
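    # Infer the problem type from num_labels and the label dtype when
    # config.problem_type is not set.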
    if @config.problem_type.nil?
      if @num_labels == 1
        @config.problem_type = "regression"
      elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
        @config.problem_type = "single_label_classification"
      else
        @config.problem_type = "multi_label_classification"
      end
    end

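    # Select the loss to match the problem type: MSE for regression,
    # cross-entropy for single-label, BCE-with-logits for multi-label.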
    if @config.problem_type == "regression"
      loss_fct = Torch::NN::MSELoss.new
      if @num_labels == 1
        loss = loss_fct.(logits.squeeze, labels.squeeze)
      else
        loss = loss_fct.(logits, labels)
      end
    elsif @config.problem_type == "single_label_classification"
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    elsif @config.problem_type == "multi_label_classification"
      loss_fct = Torch::NN::BCEWithLogitsLoss.new
      loss = loss_fct.(logits, labels)
    end
  end
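  # When return_dict is false, return an array instead: [loss (if computed),
  # logits, then any hidden states / attentions from the encoder outputs].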
  if !return_dict
    output = [logits] + outputs[2..]
    return !loss.nil? ? [loss] + output : output
  end

  SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
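
A hedged end-to-end sketch. It assumes Transformers::AutoTokenizer is available in this gem and mirrors the Python API (callable tokenizer, return_tensors: "pt", hash-style access to the encoding); the checkpoint name is illustrative.

require "transformers-rb"

tokenizer = Transformers::AutoTokenizer.from_pretrained("microsoft/mpnet-base")
model = Transformers::Mpnet::MPNetForSequenceClassification.from_pretrained("microsoft/mpnet-base")

# Tokenize a sentence and run it through the classification head.
inputs = tokenizer.("This movie was great!", return_tensors: "pt")
output = model.forward(
  input_ids: inputs[:input_ids],
  attention_mask: inputs[:attention_mask]
)

# output is a SequenceClassifierOutput; logits has shape [batch_size, num_labels].
p output.logits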