Class: Transformers::Mpnet::MPNetForSequenceClassification
- Inherits: MPNetPreTrainedModel
- Object
- Torch::NN::Module
- PreTrainedModel
- MPNetPreTrainedModel
- Transformers::Mpnet::MPNetForSequenceClassification
- Defined in:
- lib/transformers/models/mpnet/modeling_mpnet.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary collapse
- #forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
-
#initialize(config) ⇒ MPNetForSequenceClassification
constructor
A new instance of MPNetForSequenceClassification.
Methods inherited from MPNetPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from Transformers::ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ MPNetForSequenceClassification
Returns a new instance of MPNetForSequenceClassification.
542 543 544 545 546 547 548 549 550 551 |
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 542

# Builds the sequence-classification model: a bare MPNet encoder
# followed by a classification head.
#
# @param config [MPNetConfig] model configuration; +config.num_labels+
#   determines the classifier's output width.
def initialize(config)
  super(config)

  @num_labels = config.num_labels
  # No pooling layer here — the classification head consumes the raw
  # sequence output directly (see #forward).
  @mpnet = MPNetModel.new(config, add_pooling_layer: false)
  @classifier = MPNetClassificationHead.new(config)

  # Initialize weights and apply final processing
  post_init
end
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 |
# File 'lib/transformers/models/mpnet/modeling_mpnet.rb', line 553

# Runs the MPNet encoder and classification head.
#
# @param input_ids [Torch::Tensor, nil] token ids
# @param attention_mask [Torch::Tensor, nil]
# @param position_ids [Torch::Tensor, nil]
# @param head_mask [Torch::Tensor, nil]
# @param inputs_embeds [Torch::Tensor, nil] precomputed embeddings,
#   alternative to +input_ids+
# @param labels [Torch::Tensor, nil] when given, a loss is computed;
#   the loss function is chosen from +config.problem_type+ (inferred on
#   first use from +num_labels+ and the label dtype)
# @param output_attentions [Boolean, nil]
# @param output_hidden_states [Boolean, nil]
# @param return_dict [Boolean, nil] defaults to +config.use_return_dict+
# @return [SequenceClassifierOutput, Array] output object, or an array
#   ([loss,] logits, *extras) when +return_dict+ is false
def forward(
  input_ids: nil,
  attention_mask: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  # FIX: `inputs_embeds:` previously had no value (garbled/missing
  # argument) — the embeddings were never forwarded to the encoder.
  outputs = @mpnet.(input_ids, attention_mask: attention_mask, position_ids: position_ids, head_mask: head_mask, inputs_embeds: inputs_embeds, output_attentions: output_attentions, output_hidden_states: output_hidden_states, return_dict: return_dict)

  sequence_output = outputs[0]
  logits = @classifier.(sequence_output)

  loss = nil
  if !labels.nil?
    if @config.problem_type.nil?
      # Infer the problem type once and cache it on the config, mirroring
      # the Python implementation (which sets self.config.problem_type).
      # FIX: this previously assigned @problem_type, which the branches
      # below never read — so no loss was computed when problem_type was
      # unset in the config.
      if @num_labels == 1
        @config.problem_type = "regression"
      elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
        # FIX: parentheses added — `&&` binds tighter than `||`, so the
        # old condition treated any integer-dtype labels as single-label
        # classification even when @num_labels == 1.
        @config.problem_type = "single_label_classification"
      else
        @config.problem_type = "multi_label_classification"
      end
    end

    if @config.problem_type == "regression"
      loss_fct = Torch::NN::MSELoss.new
      loss =
        if @num_labels == 1
          loss_fct.(logits.squeeze, labels.squeeze)
        else
          loss_fct.(logits, labels)
        end
    elsif @config.problem_type == "single_label_classification"
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    elsif @config.problem_type == "multi_label_classification"
      loss_fct = Torch::NN::BCEWithLogitsLoss.new
      loss = loss_fct.(logits, labels)
    end
  end

  if !return_dict
    # Tuple-style return: ([loss,] logits, hidden_states?, attentions?)
    output = [logits] + outputs[2..]
    return !loss.nil? ? [loss] + output : output
  end

  SequenceClassifierOutput.new(loss: loss, logits: logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end