Class: Transformers::Distilbert::DistilBertForSequenceClassification

Inherits:
DistilBertPreTrainedModel
Defined in:
lib/transformers/models/distilbert/modeling_distilbert.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

#forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
#initialize(config) ⇒ DistilBertForSequenceClassification constructor

Methods inherited from DistilBertPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ DistilBertForSequenceClassification

Returns a new instance of DistilBertForSequenceClassification.



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 486

def initialize(config)
  super(config)
  @num_labels = config.num_labels
  @config = config

  @distilbert = DistilBertModel.new(config)
  @pre_classifier = Torch::NN::Linear.new(config.dim, config.dim)
  @classifier = Torch::NN::Linear.new(config.dim, config.num_labels)
  @dropout = Torch::NN::Dropout.new(p: config.seq_classif_dropout)

  # Initialize weights and apply final processing
  post_init
end
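
For orientation, a minimal usage sketch (not taken from the gem's docs): the checkpoint name below is illustrative, and from_pretrained is the class method inherited from PreTrainedModel.

# Construct directly from a config object...
model = Transformers::Distilbert::DistilBertForSequenceClassification.new(config)

# ...or load pretrained weights via the inherited class method
# (checkpoint name is illustrative)
model = Transformers::Distilbert::DistilBertForSequenceClassification.from_pretrained(
  "distilbert-base-uncased-finetuned-sst-2-english"
)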

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 500

def forward(
  input_ids: nil,
  attention_mask: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  distilbert_output =
    @distilbert.(
      input_ids: input_ids,
      attention_mask: attention_mask,
      head_mask: head_mask,
      inputs_embeds: inputs_embeds,
      output_attentions: output_attentions,
      output_hidden_states: output_hidden_states,
      return_dict: return_dict
    )
  hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
  pooled_output = hidden_state[0.., 0]  # (bs, dim)
  pooled_output = @pre_classifier.(pooled_output)  # (bs, dim)
  pooled_output = Torch::NN::ReLU.new.(pooled_output)  # (bs, dim)
  pooled_output = @dropout.(pooled_output)  # (bs, dim)
  logits = @classifier.(pooled_output)  # (bs, num_labels)

  loss = nil
  if !labels.nil?
    raise Todo
  end

  if !return_dict
    raise Todo
  end

  SequenceClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: distilbert_output.hidden_states,
    attentions: distilbert_output.attentions
  )
end
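
A hedged end-to-end sketch, assuming the gem's Transformers::AutoTokenizer is available and returns tensors usable here; the checkpoint name is illustrative. Note that in this version forward raises Todo when labels are passed or when return_dict is false, so only the logits path is exercised below.

tokenizer = Transformers::AutoTokenizer.from_pretrained(
  "distilbert-base-uncased-finetuned-sst-2-english"
)
inputs = tokenizer.("This movie was great!", return_tensors: "pt")

output = model.forward(
  input_ids: inputs[:input_ids],
  attention_mask: inputs[:attention_mask]
)
predicted_id = output.logits.argmax(-1).item  # index of the highest-scoring label
# Assumes the config exposes id2label, as in the Python library
label = model.config.id2label[predicted_id]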