Class: Transformers::Bert::BertForSequenceClassification
- Inherits:
-
BertPreTrainedModel
- Object
- Torch::NN::Module
- PreTrainedModel
- BertPreTrainedModel
- Transformers::Bert::BertForSequenceClassification
- Defined in:
- lib/transformers/models/bert/modeling_bert.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary collapse
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
-
#initialize(config) ⇒ BertForSequenceClassification
constructor
A new instance of BertForSequenceClassification.
Methods inherited from BertPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ BertForSequenceClassification
Returns a new instance of BertForSequenceClassification.
834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 |
# File 'lib/transformers/models/bert/modeling_bert.rb', line 834

# Builds the sequence-classification model: a pooled BERT encoder followed by
# dropout and a linear classification head sized to +config.num_labels+.
#
# @param config [BertConfig] model configuration; reads +num_labels+,
#   +classifier_dropout+, +hidden_dropout_prob+, and +hidden_size+.
def initialize(config)
  super
  @num_labels = config.num_labels
  @config = config

  # Pooling layer is required: the classifier consumes the pooled [CLS] output.
  @bert = BertModel.new(config, add_pooling_layer: true)

  # Fall back to the encoder's hidden dropout rate when no
  # classifier-specific rate was configured.
  dropout_rate = config.classifier_dropout
  dropout_rate = config.hidden_dropout_prob if dropout_rate.nil?
  @dropout = Torch::NN::Dropout.new(p: dropout_rate)
  @classifier = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 |
# File 'lib/transformers/models/bert/modeling_bert.rb', line 850

# Forward pass: BERT encoder -> pooled output -> dropout -> linear classifier,
# optionally computing a loss when +labels+ are given.
#
# @param input_ids [Torch::Tensor, nil] token ids
# @param attention_mask [Torch::Tensor, nil]
# @param token_type_ids [Torch::Tensor, nil]
# @param position_ids [Torch::Tensor, nil]
# @param head_mask [Torch::Tensor, nil]
# @param inputs_embeds [Torch::Tensor, nil] precomputed embeddings, used in
#   place of +input_ids+
# @param labels [Torch::Tensor, nil] targets; when present, a loss is computed
#   whose type is selected by +config.problem_type+
# @param output_attentions [Boolean, nil]
# @param output_hidden_states [Boolean, nil]
# @param return_dict [Boolean, nil] defaults to +config.use_return_dict+
# @return [SequenceClassifierOutput]
# @raise [Todo] the non-dict (tuple) return path is not implemented
def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = @config.use_return_dict if return_dict.nil?

  outputs = @bert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    # FIX: the keyword value was missing (`inputs_embeds: ,` — a syntax
    # error); restore the pass-through of the inputs_embeds argument.
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  # Index 1 of the encoder output is the pooled [CLS] representation.
  pooled_output = outputs[1]

  pooled_output = @dropout.(pooled_output)
  logits = @classifier.(pooled_output)

  loss = nil
  if !labels.nil?
    # Infer the problem type once from label count/dtype and cache it on
    # the config so subsequent calls reuse the same loss selection.
    if @config.problem_type.nil?
      if @num_labels == 1
        @config.problem_type = "regression"
      elsif @num_labels > 1 && (labels.dtype == Torch.long || labels.dtype == Torch.int)
        @config.problem_type = "single_label_classification"
      else
        @config.problem_type = "multi_label_classification"
      end
    end

    if @config.problem_type == "regression"
      loss_fct = Torch::NN::MSELoss.new
      loss =
        if @num_labels == 1
          loss_fct.(logits.squeeze, labels.squeeze)
        else
          loss_fct.(logits, labels)
        end
    elsif @config.problem_type == "single_label_classification"
      loss_fct = Torch::NN::CrossEntropyLoss.new
      loss = loss_fct.(logits.view(-1, @num_labels), labels.view(-1))
    elsif @config.problem_type == "multi_label_classification"
      loss_fct = Torch::NN::BCEWithLogitsLoss.new
      loss = loss_fct.(logits, labels)
    end
  end

  if !return_dict
    # Tuple-style return is not implemented in this port.
    raise Todo
  end

  SequenceClassifierOutput.new(
    loss: loss,
    logits: logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end