Class: Transformers::DebertaV2::DebertaV2ForQuestionAnswering
- Inherits: DebertaV2PreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - DebertaV2PreTrainedModel
  - Transformers::DebertaV2::DebertaV2ForQuestionAnswering
- Defined in: lib/transformers/models/deberta_v2/modeling_deberta_v2.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, start_positions: nil, end_positions: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ DebertaV2ForQuestionAnswering (constructor)
  A new instance of DebertaV2ForQuestionAnswering.
Methods inherited from DebertaV2PreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ DebertaV2ForQuestionAnswering
Returns a new instance of DebertaV2ForQuestionAnswering.
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1074

def initialize(config)
  super(config)
  @num_labels = config.num_labels

  @deberta = DebertaV2Model.new(config)
  @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
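The constructor wires a DebertaV2Model backbone to a single linear span-classification head with config.num_labels outputs (start and end scores). A minimal construction sketch, assuming Transformers::DebertaV2::DebertaV2Config is the companion configuration class and using deliberately small, illustrative hyperparameters (both the class name and the sizes are assumptions, not taken from this page):

# Assumes the transformers-rb and torch-rb gems are already loaded (e.g. via Bundler).

# Assumption: DebertaV2Config mirrors the Python configuration class; the tiny
# sizes below are arbitrary values chosen only so the example builds quickly.
config = Transformers::DebertaV2::DebertaV2Config.new(
  vocab_size: 1000,
  hidden_size: 128,
  num_hidden_layers: 2,
  num_attention_heads: 4,
  intermediate_size: 256
)

# Randomly initialized model: a DebertaV2Model backbone plus a Linear QA head.
# num_labels defaults to 2 in the base configuration, so the head emits one
# start score and one end score per token.
model = Transformers::DebertaV2::DebertaV2ForQuestionAnswering.new(config)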
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, inputs_embeds: nil, start_positions: nil, end_positions: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/deberta_v2/modeling_deberta_v2.rb', line 1085

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  inputs_embeds: nil,
  start_positions: nil,
  end_positions: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @deberta.(
    input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]

  logits = @qa_outputs.(sequence_output)
  start_logits, end_logits = logits.split(1, dim: -1)
  start_logits = start_logits.squeeze(-1).contiguous
  end_logits = end_logits.squeeze(-1).contiguous

  total_loss = nil
  if !start_positions.nil? && !end_positions.nil?
    # If we are on multi-GPU, split adds a dimension
    if start_positions.size.length > 1
      start_positions = start_positions.squeeze(-1)
    end
    if end_positions.size.length > 1
      end_positions = end_positions.squeeze(-1)
    end
    # Sometimes the start/end positions are outside our model inputs; we ignore these terms
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
    start_loss = loss_fct.(start_logits, start_positions)
    end_loss = loss_fct.(end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2
  end

  if !return_dict
    output = [start_logits, end_logits] + outputs[1..]
    return !total_loss.nil? ? [total_loss] + output : output
  end

  QuestionAnsweringModelOutput.new(
    loss: total_loss,
    start_logits: start_logits,
    end_logits: end_logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
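forward returns per-token start_logits and end_logits of shape [batch_size, sequence_length]; when start_positions and end_positions are supplied it also returns the averaged start/end cross-entropy loss. A usage sketch, assuming a publicly available DeBERTa-v3 extractive-QA checkpoint and that Transformers::AutoTokenizer follows the Python-style call signature (the checkpoint id, the tokenizer class, and its keyword arguments are assumptions, not taken from this page):

# Assumes the transformers-rb and torch-rb gems are already loaded (e.g. via Bundler).
model_id = "deepset/deberta-v3-base-squad2" # illustrative checkpoint id

tokenizer = Transformers::AutoTokenizer.from_pretrained(model_id) # assumed tokenizer API
model = Transformers::DebertaV2::DebertaV2ForQuestionAnswering.from_pretrained(model_id)

question = "Who created Ruby?"
context = "Ruby is a programming language created by Yukihiro Matsumoto."

# Assumption: the tokenizer call mirrors the Python API (text_pair:, return_tensors:).
encoding = tokenizer.(question, text_pair: context, return_tensors: "pt")

output = model.forward(
  input_ids: encoding[:input_ids],
  attention_mask: encoding[:attention_mask]
)

# Inference: take the highest-scoring start and end tokens and decode that span.
start_index = output.start_logits.argmax(1).item
end_index = output.end_logits.argmax(1).item
answer_ids = encoding[:input_ids][0][start_index..end_index]
puts tokenizer.decode(answer_ids.to_a)

# Training-style call: passing gold span positions makes forward also return
# the loss (the average of the start and end cross-entropy terms).
labeled = model.forward(
  input_ids: encoding[:input_ids],
  attention_mask: encoding[:attention_mask],
  start_positions: Torch.tensor([start_index]),
  end_positions: Torch.tensor([end_index])
)
labeled.loss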