Class: Transformers::Mpnet::MPNetForTokenClassification
- Inherits: MPNetPreTrainedModel
- Object
- Torch::NN::Module
- PreTrainedModel
- MPNetPreTrainedModel
- Transformers::Mpnet::MPNetForTokenClassification
- Defined in:
- lib/transformers/models/mpnet/modeling_mpnet.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary collapse
- #forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ MPNetForTokenClassification (constructor) — returns a new instance of MPNetForTokenClassification.
Methods inherited from MPNetPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from Transformers::ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ MPNetForTokenClassification
Returns a new instance of MPNetForTokenClassification.
660 661 662 663 664 665 666 667 668 669 670 |
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 660

# Builds a token-classification head on top of a bare MPNet encoder.
#
# @param config [Object] model configuration; must respond to
#   +num_labels+, +hidden_dropout_prob+ and +hidden_size+
def initialize(config)
  super(config)
  @num_labels = config.num_labels

  # Pooling is disabled: token classification consumes the per-token
  # hidden states, not a pooled sentence representation.
  @mpnet = MPNetModel.new(config, add_pooling_layer: false)
  @dropout = Torch::NN::Dropout.new(p: config.hidden_dropout_prob)
  @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 |
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 672 def forward( input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil ) return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict outputs = @mpnet.(input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: , output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict) sequence_output = outputs[0] sequence_output = @dropout.(sequence_output) logits = @classifier.(sequence_output) loss = nil if !labels.nil? loss_fct = Torch::NN::CrossEntropyLoss.new loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1)) end if !return_dict output = [logits] + outputs[2..] return !loss.nil? ? [loss] + output : output end TokenClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions) end |