Class: Transformers::Bert::BertForSequenceClassification

Inherits:
BertPreTrainedModel
Defined in:
lib/transformers/models/bert/modeling_bert.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

Methods inherited from BertPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ BertForSequenceClassification

Returns a new instance of BertForSequenceClassification.



# File 'lib/transformers/models/bert/modeling_bert.rb', line 834

def initialize(config)
  super
  @num_labels = config.num_labels
  @config = config

  @bert = BertModel.new(config, add_pooling_layer: true)
  classifier_dropout = (
    config.classifier_dropout.nil? ? config.hidden_dropout_prob : config.classifier_dropout
  )
  @dropout = Torch::NN::Dropout.new(p: classifier_dropout)
  @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
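A minimal construction sketch. In practice the class is usually built through the inherited from_pretrained, which loads the config and weights and then invokes this constructor; the checkpoint id below is illustrative only, and the snippet assumes the transformers-rb gem is loaded.

# Load config and weights from a checkpoint (from_pretrained is inherited
# from PreTrainedModel); the checkpoint id here is illustrative.
model = Transformers::Bert::BertForSequenceClassification.from_pretrained(
  "textattack/bert-base-uncased-SST-2"
)

# Direct construction is also possible when a config object is already at
# hand (for example, taken from another loaded model):
# model = Transformers::Bert::BertForSequenceClassification.new(config)

model.eval  # switch dropout to inference mode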

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



# File 'lib/transformers/models/bert/modeling_bert.rb', line 850

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = @config.use_return_dict if return_dict.nil?

  outputs = @bert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  pooled_output = outputs[1]

  pooled_output = @dropout.(pooled_output)
  logits = @classifier.(pooled_output)

  loss = nil
  if !labels.nil?
    if @config.problem_type.nil?
      if @num_labels == 1
        @config.problem_type = "regression"
      elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
        @config.problem_type = "single_label_classification"
      else
        @config.problem_type = "multi_label_classification"
      end
    end

    if @config.problem_type == "regression"
      loss_fct = Torch::NN::MSELoss.new
      if @num_labels == 1
        loss = loss_fct.(logits.squeeze, labels.squeeze)
      else
        loss = loss_fct.(logits, labels)
      end
    elsif @config.problem_type == "single_label_classification"
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    elsif @config.problem_type == "multi_label_classification"
      loss_fct = Torch::NN::BCEWithLogitsLoss.new
      loss = loss_fct.(logits, labels)
    end
  end

  if !return_dict
    raise Todo
  end

  SequenceClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
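A usage sketch for this method, reusing the model from the constructor sketch above. It assumes a matching tokenizer is available via Transformers::AutoTokenizer and that the returned encoding exposes its tensors under symbol keys; those accessor details, the checkpoint id, and the example text are assumptions, not part of this API.

# Tokenize a sentence into Torch tensors (assumed AutoTokenizer API).
tokenizer = Transformers::AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
encoding = tokenizer.("The movie was great!", return_tensors: "pt")

# Calling the model invokes #forward with the given keyword arguments.
output = model.(
  input_ids: encoding[:input_ids],
  attention_mask: encoding[:attention_mask]
)

logits = output.logits                                   # shape: [batch, num_labels]
probs = Torch::NN::Functional.softmax(logits, dim: -1)   # per-class probabilities
predicted = logits.argmax(-1)                            # index of the highest-scoring label

# Passing integer labels as a Torch long tensor also populates output.loss:
# output = model.(input_ids: ids, attention_mask: mask, labels: Torch.tensor([1]))

When labels are supplied and config.problem_type is unset, the method infers it from num_labels and the label dtype: a single label column is treated as regression (MSELoss), integer labels with more than one class as single-label classification (CrossEntropyLoss), and anything else as multi-label classification (BCEWithLogitsLoss).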