Class: Transformers::DebertaV2::DebertaV2ForMultipleChoice

Inherits:
DebertaV2PreTrainedModel show all
Defined in:
lib/transformers/models/deberta_v2/modeling_deberta_v2.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary collapse

Methods inherited from DebertaV2PreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ DebertaV2ForMultipleChoice

Returns a new instance of DebertaV2ForMultipleChoice.



1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1139

# Builds a DeBERTa-v2 model with a multiple-choice head: the shared
# encoder, a context pooler, and a single-logit linear classifier
# (one score per choice; scores are reshaped per example in +forward+).
#
# @param config [Object] model configuration (responds to +getattr+ and
#   +hidden_dropout_prob+)
def initialize(config)
  super(config)

  # Default to 2 labels when the config does not specify num_labels.
  @num_labels = config.getattr("num_labels", 2)

  @deberta = DebertaV2Model.new(config)
  @pooler = ContextPooler.new(config)
  output_dim = @pooler.output_dim

  # One logit per choice; choices are folded into the batch dimension.
  @classifier = Torch::NN::Linear.new(output_dim, 1)
  # Consistently read from the local `config` (the original mixed it with
  # `@config`, which `super` assigns to the same object).
  drop_out = config.getattr("cls_dropout", nil)
  drop_out = drop_out.nil? ? config.hidden_dropout_prob : drop_out
  @dropout = StableDropout.new(drop_out)

  init_weights
end

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, labels: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object



1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1165

# Runs the multiple-choice head. The choice dimension is flattened into
# the batch before the encoder call and restored on the logits, so the
# classifier scores each choice independently.
#
# @return [Array, MultipleChoiceModelOutput] tuple when +return_dict+ is
#   falsy, otherwise a MultipleChoiceModelOutput
def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  inputs_embeds: nil,
  labels: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = @config.use_return_dict if return_dict.nil?
  num_choices = input_ids.nil? ? inputs_embeds.shape[1] : input_ids.shape[1]

  # Collapse (batch, num_choices, ...) into (batch * num_choices, ...).
  flatten2d = ->(t) { t.nil? ? nil : t.view(-1, t.size(-1)) }
  flat_input_ids = flatten2d.(input_ids)
  flat_position_ids = flatten2d.(position_ids)
  flat_token_type_ids = flatten2d.(token_type_ids)
  flat_attention_mask = flatten2d.(attention_mask)
  # Embeddings keep their trailing hidden dimension, so flatten to 3-D.
  flat_inputs_embeds =
    if inputs_embeds.nil?
      nil
    else
      inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
    end

  outputs = @deberta.(
    flat_input_ids,
    position_ids: flat_position_ids,
    token_type_ids: flat_token_type_ids,
    attention_mask: flat_attention_mask,
    inputs_embeds: flat_inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  pooled = @dropout.(@pooler.(outputs[0]))
  # One logit per flattened row, regrouped as (batch, num_choices).
  reshaped_logits = @classifier.(pooled).view(-1, num_choices)

  loss = labels.nil? ? nil : Torch::NN::CrossEntropyLoss.new.(reshaped_logits, labels)

  unless return_dict
    tail = [reshaped_logits] + outputs[1..]
    return loss.nil? ? tail : [loss] + tail
  end

  MultipleChoiceModelOutput.new(
    loss: loss,
    logits: reshaped_logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end

#get_input_embeddings ⇒ Object



1157
1158
1159
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1157

# The input embeddings live on the DeBERTa backbone; expose them directly.
#
# @return [Object] the backbone's input embedding module
def get_input_embeddings
  @deberta.public_send(:get_input_embeddings)
end

#set_input_embeddings(new_embeddings) ⇒ Object



1161
1162
1163
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1161

# Replaces the input embedding module on the DeBERTa backbone.
#
# @param new_embeddings [Object] embedding module to install
def set_input_embeddings(new_embeddings)
  @deberta.public_send(:set_input_embeddings, new_embeddings)
end