Class: Transformers::DebertaV2::DebertaV2ForMultipleChoice
- Inherits: DebertaV2PreTrainedModel
  - Object
    - Torch::NN::Module
      - PreTrainedModel
        - DebertaV2PreTrainedModel
          - Transformers::DebertaV2::DebertaV2ForMultipleChoice
- Defined in: lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #get_input_embeddings ⇒ Object
- #initialize(config) ⇒ DebertaV2ForMultipleChoice (constructor): Returns a new instance of DebertaV2ForMultipleChoice.
- #set_input_embeddings(new_embeddings) ⇒ Object
Methods inherited from DebertaV2PreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, .from_pretrained, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ DebertaV2ForMultipleChoice
Returns a new instance of DebertaV2ForMultipleChoice.
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1139
def initialize(config)
  super(config)

  num_labels = config.getattr("num_labels", 2)
  @num_labels = num_labels

  @deberta = DebertaV2Model.new(config)
  @pooler = ContextPooler.new(config)
  output_dim = @pooler.output_dim

  @classifier = Torch::NN::Linear.new(output_dim, 1)
  drop_out = config.getattr("cls_dropout", nil)
  drop_out = drop_out.nil? ? @config.hidden_dropout_prob : drop_out
  @dropout = StableDropout.new(drop_out)

  init_weights
end
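The constructor always builds the classifier as `Linear(output_dim, 1)`: it emits one score per (example, choice) pair, so `num_labels` is stored but does not set the head's width. The dropout rate comes from `cls_dropout` when the config defines it and falls back to `hidden_dropout_prob` otherwise. The plain-Ruby sketch below isolates that fallback; `SketchConfig` and `classifier_dropout_prob` are hypothetical stand-ins, not part of transformers-rb, which reads these values through `config.getattr` instead.

```ruby
# Hypothetical stand-in for the two config values the constructor consults.
SketchConfig = Struct.new(:cls_dropout, :hidden_dropout_prob, keyword_init: true)

# Mirrors the constructor's rule: an explicit cls_dropout wins (even 0.0),
# and only a missing (nil) value falls back to hidden_dropout_prob.
def classifier_dropout_prob(config)
  config.cls_dropout.nil? ? config.hidden_dropout_prob : config.cls_dropout
end

unset = SketchConfig.new(cls_dropout: nil, hidden_dropout_prob: 0.1)
zeroed = SketchConfig.new(cls_dropout: 0.0, hidden_dropout_prob: 0.1)

puts classifier_dropout_prob(unset)  # prints 0.1
puts classifier_dropout_prob(zeroed) # prints 0.0
```

Checking `nil?` rather than truthiness keeps an explicit `0.0` from being silently replaced by the hidden-layer rate.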
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1165
def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict
  num_choices = !input_ids.nil? ? input_ids.shape[1] : inputs_embeds.shape[1]

  flat_input_ids = !input_ids.nil? ? input_ids.view(-1, input_ids.size(-1)) : nil
  flat_position_ids = !position_ids.nil? ? position_ids.view(-1, position_ids.size(-1)) : nil
  flat_token_type_ids = !token_type_ids.nil? ? token_type_ids.view(-1, token_type_ids.size(-1)) : nil
  flat_attention_mask = !attention_mask.nil? ? attention_mask.view(-1, attention_mask.size(-1)) : nil
  flat_inputs_embeds = !inputs_embeds.nil? ? inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) : nil

  outputs = @deberta.(
    flat_input_ids,
    position_ids: flat_position_ids,
    token_type_ids: flat_token_type_ids,
    attention_mask: flat_attention_mask,
    inputs_embeds: flat_inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  encoder_layer = outputs[0]
  pooled_output = @pooler.(encoder_layer)
  pooled_output = @dropout.(pooled_output)
  logits = @classifier.(pooled_output)
  reshaped_logits = logits.view(-1, num_choices)

  loss = nil
  if !labels.nil?
    loss_fct = Torch::NN::CrossEntropyLoss.new
    loss = loss_fct.(reshaped_logits, labels)
  end
  if !return_dict
    output = [reshaped_logits] + outputs[1..]
    return !loss.nil? ? [loss] + output : output
  end

  MultipleChoiceModelOutput.new(loss: loss, logits: reshaped_logits, hidden_states: outputs.hidden_states, attentions: outputs.attentions)
end
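`#forward` folds the choice dimension into the batch dimension before running the encoder, scores every flattened row with the single-output classifier, then views the scores back as `[batch, num_choices]` so cross-entropy can pick the correct choice per example. The sketch below replays that bookkeeping with nested Ruby arrays instead of Torch tensors; `flatten_choices`, `reshape_logits`, and `cross_entropy` are hypothetical helpers, the token IDs are arbitrary, and the per-row scores are made up rather than produced by a model.

```ruby
# [batch][num_choices][seq_len] -> [batch * num_choices][seq_len],
# analogous to input_ids.view(-1, input_ids.size(-1)).
def flatten_choices(batch)
  batch.flat_map { |choices| choices }
end

# [batch * num_choices] scores -> [batch][num_choices],
# analogous to logits.view(-1, num_choices).
def reshape_logits(flat_scores, num_choices)
  flat_scores.each_slice(num_choices).to_a
end

# Mean cross-entropy where each row of logits is a distribution over choices
# and labels[i] is the index of the correct choice for example i.
def cross_entropy(logits, labels)
  losses = logits.zip(labels).map do |row, label|
    max = row.max # subtract the max for a numerically stable log-sum-exp
    log_sum_exp = max + Math.log(row.sum { |x| Math.exp(x - max) })
    log_sum_exp - row[label]
  end
  losses.sum / losses.size
end

input_ids = [
  [[101, 7, 102], [101, 8, 102]],  # example 0: two choices, seq_len 3
  [[101, 9, 102], [101, 10, 102]]  # example 1: two choices, seq_len 3
]
num_choices = input_ids[0].size    # like input_ids.shape[1], here 2
flat = flatten_choices(input_ids)  # 4 rows, each of length 3

# Pretend the encoder, pooler, and classifier produced one score per flat row.
flat_scores = [0.2, 1.5, -0.3, 0.4]
logits = reshape_logits(flat_scores, num_choices)
p logits                           # prints [[0.2, 1.5], [-0.3, 0.4]]
loss = cross_entropy(logits, [1, 1])
```

When `return_dict` is false, the real method skips `MultipleChoiceModelOutput` and returns a plain array instead: `[loss, reshaped_logits, *rest]` if `labels` were given, otherwise `[reshaped_logits, *rest]`, where `rest` is whatever the encoder returned after its first element.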
#get_input_embeddings ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1157
def get_input_embeddings
  @deberta.get_input_embeddings
end
#set_input_embeddings(new_embeddings) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1161
def set_input_embeddings(new_embeddings)
  @deberta.set_input_embeddings(new_embeddings)
end
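Both embedding accessors are thin pass-throughs: the task head owns no embedding table of its own, so reads and writes go to the wrapped `@deberta` backbone. The sketch below reproduces that delegation pattern in plain Ruby. `SketchBackbone` and `SketchMultipleChoiceHead` are hypothetical classes, and it swaps in the standard library's `Forwardable` (`def_delegators`) where the real class writes the two methods out by hand; the effect is the same forwarding.

```ruby
require "forwardable"

# Hypothetical backbone that owns the embedding table.
class SketchBackbone
  def initialize(word_embeddings)
    @word_embeddings = word_embeddings
  end

  def get_input_embeddings
    @word_embeddings
  end

  def set_input_embeddings(new_embeddings)
    @word_embeddings = new_embeddings
  end
end

# Hypothetical task head that forwards both accessors to its backbone.
class SketchMultipleChoiceHead
  extend Forwardable
  def_delegators :@backbone, :get_input_embeddings, :set_input_embeddings

  def initialize(backbone)
    @backbone = backbone
  end
end

backbone = SketchBackbone.new(:original_table)
head = SketchMultipleChoiceHead.new(backbone)
head.set_input_embeddings(:resized_table)
p head.get_input_embeddings     # prints :resized_table
p backbone.get_input_embeddings # prints :resized_table
```

Because the write lands on the backbone, any other code holding a reference to that backbone sees the replaced table too.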