Class: Transformers::Distilbert::DistilBertForSequenceClassification
- Inherits: DistilBertPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - DistilBertPreTrainedModel
  - Transformers::Distilbert::DistilBertForSequenceClassification
- Defined in: lib/transformers/models/distilbert/modeling_distilbert.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ DistilBertForSequenceClassification (constructor)
  A new instance of DistilBertForSequenceClassification.
Methods inherited from DistilBertPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ DistilBertForSequenceClassification
Returns a new instance of DistilBertForSequenceClassification.
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 486

def initialize(config)
  super(config)
  @num_labels = config.num_labels
  @config = config

  @distilbert = DistilBertModel.new(config)
  @pre_classifier = Torch::NN::Linear.new(config.dim, config.dim)
  @classifier = Torch::NN::Linear.new(config.dim, config.num_labels)
  @dropout = Torch::NN::Dropout.new(p: config.seq_classif_dropout)

  # Initialize weights and apply final processing
  post_init
end
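As a rough illustration of the head the constructor builds (a minimal plain-Ruby sketch, no torch.rb required; the `dim = 768` and `num_labels = 2` values are assumptions matching a standard DistilBERT configuration, not taken from this file): the classification head maps the pooled `(bs, dim)` output through `Linear(dim, dim)` and then `Linear(dim, num_labels)`.

```ruby
# Stand-in for the relevant fields of the model config (values assumed)
Config = Struct.new(:num_labels, :dim, :seq_classif_dropout)
config = Config.new(2, 768, 0.2)

# Shapes of the two Linear layers created in initialize:
pre_classifier_shape = [config.dim, config.dim]         # Linear(768, 768)
classifier_shape     = [config.dim, config.num_labels]  # Linear(768, 2)

puts pre_classifier_shape.inspect  # => [768, 768]
puts classifier_shape.inspect      # => [768, 2]
```

So with these assumed values, logits for each input sequence end up with shape `(bs, 2)`.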
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 500

def forward(
  input_ids: nil,
  attention_mask: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  distilbert_output = @distilbert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  hidden_state = distilbert_output[0]                  # (bs, seq_len, dim)
  pooled_output = hidden_state[0.., 0]                 # (bs, dim)
  pooled_output = @pre_classifier.(pooled_output)      # (bs, dim)
  pooled_output = Torch::NN::ReLU.new.(pooled_output)  # (bs, dim)
  pooled_output = @dropout.(pooled_output)             # (bs, dim)
  logits = @classifier.(pooled_output)                 # (bs, num_labels)

  loss = nil
  if !labels.nil?
    raise Todo
  end

  if !return_dict
    raise Todo
  end

  SequenceClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: distilbert_output.hidden_states,
    attentions: distilbert_output.attentions
  )
end
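The pooling step `hidden_state[0.., 0]` keeps every batch item but selects only the first ([CLS]) token's hidden state, turning a `(bs, seq_len, dim)` tensor into `(bs, dim)`. A minimal plain-Ruby sketch of the same indexing on nested arrays (the numeric values are made up for illustration):

```ruby
# Nested array standing in for a (bs = 2, seq_len = 3, dim = 2) tensor
hidden_state = [
  [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],  # batch item 0
  [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]   # batch item 1
]

# Equivalent of hidden_state[0.., 0]: for each batch item, take token 0
pooled_output = hidden_state.map { |seq| seq[0] }  # (bs, dim)

puts pooled_output.inspect  # => [[0.1, 0.2], [0.7, 0.8]]
```

Only this pooled [CLS] representation flows into the pre-classifier, ReLU, dropout, and classifier layers; the other tokens' hidden states are returned (when requested) but not used for the classification logits.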