Class: Transformers::XlmRoberta::XLMRobertaForQuestionAnswering
- Inherits: XLMRobertaPreTrainedModel
  - Object
  - Torch::NN::Module
  - PreTrainedModel
  - XLMRobertaPreTrainedModel
  - Transformers::XlmRoberta::XLMRobertaForQuestionAnswering
- Defined in:
- lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb
Instance Attribute Summary
Attributes inherited from PreTrainedModel
Instance Method Summary
- #forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, start_positions: nil, end_positions: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
- #initialize(config) ⇒ XLMRobertaForQuestionAnswering (constructor)
  A new instance of XLMRobertaForQuestionAnswering.
Methods inherited from XLMRobertaPreTrainedModel
Methods inherited from PreTrainedModel
#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask
Methods included from ClassAttribute
Methods included from ModuleUtilsMixin
#device, #get_extended_attention_mask, #get_head_mask
Constructor Details
#initialize(config) ⇒ XLMRobertaForQuestionAnswering
Returns a new instance of XLMRobertaForQuestionAnswering.
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 1152

def initialize(config)
  super(config)
  @num_labels = config.num_labels

  @roberta = XLMRobertaModel.new(config, add_pooling_layer: false)
  @qa_outputs = Torch::NN::Linear.new(config.hidden_size, config.num_labels)

  # Initialize weights and apply final processing
  post_init
end
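In practice the constructor is rarely called directly; the usual entry point is from_pretrained, inherited from PreTrainedModel (see the method list above). A minimal loading sketch follows; the checkpoint id is a hypothetical example, not something this documentation prescribes.

  # Hypothetical checkpoint id; any XLM-RoBERTa checkpoint with a
  # question-answering head should work here.
  model = Transformers::XlmRoberta::XLMRobertaForQuestionAnswering.from_pretrained(
    "deepset/xlm-roberta-base-squad2"
  )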
Instance Method Details
#forward(input_ids: nil, attention_mask: nil, token_type_ids: nil, position_ids: nil, head_mask: nil, inputs_embeds: nil, start_positions: nil, end_positions: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object
# File 'lib/transformers/models/xlm_roberta/modeling_xlm_roberta.rb', line 1163

def forward(
  input_ids: nil,
  attention_mask: nil,
  token_type_ids: nil,
  position_ids: nil,
  head_mask: nil,
  inputs_embeds: nil,
  start_positions: nil,
  end_positions: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  outputs = @roberta.(
    input_ids,
    attention_mask: attention_mask,
    token_type_ids: token_type_ids,
    position_ids: position_ids,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )

  sequence_output = outputs[0]

  logits = @qa_outputs.(sequence_output)
  start_logits, end_logits = logits.split(1, dim: -1)
  start_logits = start_logits.squeeze(-1).contiguous
  end_logits = end_logits.squeeze(-1).contiguous

  total_loss = nil
  if !start_positions.nil? && !end_positions.nil?
    # If we are on multi-GPU, split adds a dimension
    if start_positions.size.length > 1
      start_positions = start_positions.squeeze(-1)
    end
    if end_positions.size.length > 1
      end_positions = end_positions.squeeze(-1)
    end
    # sometimes the start/end positions are outside our model inputs, we ignore these terms
    ignored_index = start_logits.size(1)
    start_positions = start_positions.clamp(0, ignored_index)
    end_positions = end_positions.clamp(0, ignored_index)

    loss_fct = Torch::NN::CrossEntropyLoss.new(ignore_index: ignored_index)
    start_loss = loss_fct.(start_logits, start_positions)
    end_loss = loss_fct.(end_logits, end_positions)
    total_loss = (start_loss + end_loss) / 2
  end

  if !return_dict
    output = [start_logits, end_logits] + outputs[2..]
    return !total_loss.nil? ? [total_loss] + output : output
  end

  QuestionAnsweringModelOutput.new(
    loss: total_loss,
    start_logits: start_logits,
    end_logits: end_logits,
    hidden_states: outputs.hidden_states,
    attentions: outputs.attentions
  )
end
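A hedged end-to-end sketch of extractive question answering with this class: the checkpoint id, the AutoTokenizer usage, and the option names are assumptions modeled on the Python library's conventions, not guaranteed by the documentation above. The predicted answer span is recovered by taking the argmax of the start and end logits.

  # All identifiers below marked "hypothetical" are illustrative assumptions.
  model = Transformers::XlmRoberta::XLMRobertaForQuestionAnswering.from_pretrained(
    "deepset/xlm-roberta-base-squad2" # hypothetical checkpoint id
  )
  tokenizer = Transformers::AutoTokenizer.from_pretrained("deepset/xlm-roberta-base-squad2")

  question = "Where do I live?"
  context = "My name is Tim and I live in Sweden."
  # Assumes the tokenizer accepts a text pair and returns Torch tensors.
  inputs = tokenizer.(question, text_pair: context, return_tensors: "pt")

  output = Torch.no_grad do
    model.forward(
      input_ids: inputs[:input_ids],
      attention_mask: inputs[:attention_mask]
    )
  end

  # Highest-scoring start and end token positions define the answer span.
  start_index = output.start_logits.argmax(1).item
  end_index = output.end_logits.argmax(1).item

  answer_ids = inputs[:input_ids][0][start_index..end_index]
  puts tokenizer.decode(answer_ids.to_a)

When start_positions and end_positions are also supplied, the returned output's loss field holds the cross-entropy averaged over the start and end boundaries, matching the (start_loss + end_loss) / 2 computation in the source above.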