Class: Transformers::Mpnet::MPNetForTokenClassification

Inherits:
MPNetPreTrainedModel show all
Defined in:
lib/transformers/models/mpnet/modeling_mpnet.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary collapse

Methods inherited from MPNetPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from Transformers::ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ MPNetForTokenClassification

Returns a new instance of MPNetForTokenClassification.



660
661
662
663
664
665
666
667
668
669
670
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 660

# Builds the token-classification model: an MPNet backbone created without
# its pooling layer, followed by dropout and a linear classification layer.
#
# @param config [Object] model configuration; this constructor reads
#   +num_labels+, +hidden_dropout_prob+ and +hidden_size+ from it
def initialize(config)
  super

  label_count = config.num_labels
  dropout_prob = config.hidden_dropout_prob
  @num_labels = label_count

  # Sub-modules are assigned in the same order as before (backbone, dropout,
  # classifier) in case the framework depends on registration order.
  @mpnet = MPNetModel.new(config, add_pooling_layer: false)
  @dropout = Torch::NN::Dropout.new(p: dropout_prob)
  @classifier = Torch::NN::Linear.new(config.hidden_size, label_count)

  # Initialize weights and apply final processing
  post_init
end

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 672

# Runs the MPNet backbone, applies dropout and the linear classifier to its
# first output, and optionally computes a cross-entropy loss against +labels+.
#
# All tensor-like arguments are forwarded untouched to the backbone; their
# expected shapes are defined there, not here.
#
# @param input_ids [Object, nil] forwarded to the backbone as its positional argument
# @param attention_mask [Object, nil] forwarded to the backbone
# @param position_ids [Object, nil] forwarded to the backbone
# @param head_mask [Object, nil] forwarded to the backbone
# @param inputs_embeds [Object, nil] forwarded to the backbone
# @param labels [Object, nil] when given, flattened with +view(-1)+ and used
#   as the cross-entropy target
# @param output_attentions [Boolean, nil] forwarded to the backbone
# @param output_hidden_states [Boolean, nil] forwarded to the backbone
# @param return_dict [Boolean, nil] +nil+ falls back to +@config.use_return_dict+;
#   an explicit +false+ is respected
# @return [TokenClassifierOutput, Array] a +TokenClassifierOutput+ when
#   +return_dict+ is truthy; otherwise an array of
#   +[loss (only when labels were given), logits, *backbone_outputs_from_index_2]+
def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  # Only nil triggers the config default; `||=` would override an explicit false.
  return_dict = @config.use_return_dict if return_dict.nil?

  outputs = @mpnet.(
    input_ids,
    attention_mask: attention_mask,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = @dropout.(outputs[0])
  logits = @classifier.(sequence_output)

  loss = nil
  unless labels.nil?
    loss_fct = Torch::NN::CrossEntropyLoss.new
    # Flatten so each row of logits lines up with one label entry.
    loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
  end

  unless return_dict
    # Array#[] with a range starting past the end returns nil (not []), so a
    # backbone tuple shorter than two elements must fall back to an empty
    # list; otherwise `[logits] + nil` raises TypeError.
    extras = outputs[2..] || []
    output = [logits] + extras
    return loss.nil? ? output : [loss] + output
  end

  TokenClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end