Class: Transformers::Distilbert::DistilBertForQuestionAnswering

Inherits:
DistilBertPreTrainedModel
Defined in:
lib/transformers/models/distilbert/modeling_distilbert.rb

Instance Attribute Summary

Attributes inherited from PreTrainedModel

#config

Instance Method Summary

#forward ⇒ Object

Methods inherited from DistilBertPreTrainedModel

#_init_weights

Methods inherited from PreTrainedModel

#_backward_compatibility_gradient_checkpointing, #_init_weights, #_initialize_weights, #base_model, #can_generate, #dequantize, #dummy_inputs, #framework, from_pretrained, #get_input_embeddings, #get_output_embeddings, #init_weights, #post_init, #prune_heads, #set_input_embeddings, #tie_weights, #warn_if_padding_and_no_attention_mask

Methods included from ClassAttribute

#class_attribute

Methods included from ModuleUtilsMixin

#device, #get_extended_attention_mask, #get_head_mask

Constructor Details

#initialize(config) ⇒ DistilBertForQuestionAnswering

Returns a new instance of DistilBertForQuestionAnswering.

# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 548

def initialize(config)
  super(config)

  @distilbert = DistilBertModel.new(config)
  @qa_outputs = Torch::NN::Linear.new(config.dim, config.num_labels)
  if config.num_labels != 2
    raise ArgumentError, "config.num_labels should be 2, but it is #{config.num_labels}"
  end

  @dropout = Torch::NN::Dropout.new(p: config.qa_dropout)

  # Initialize weights and apply final processing
  post_init
end
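
The constructor wraps a plain DistilBertModel backbone and adds a single Linear head (@qa_outputs) that projects each token's hidden state of size config.dim to two scores, one for the start and one for the end of the answer span, which is why config.num_labels must be 2. A minimal loading sketch using the inherited from_pretrained class method; the checkpoint name is illustrative, not taken from this file:

model = Transformers::Distilbert::DistilBertForQuestionAnswering.from_pretrained(
  "distilbert-base-cased-distilled-squad"  # illustrative SQuAD-finetuned checkpoint
)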

Instance Method Details

#forward(input_ids: nil, attention_mask: nil, head_mask: nil, inputs_embeds: nil, start_positions: nil, end_positions: nil, output_attentions: nil, output_hidden_states: nil, return_dict: nil) ⇒ Object

# File 'lib/transformers/models/distilbert/modeling_distilbert.rb', line 563

def forward(
  input_ids: nil,
  attention_mask: nil,
  head_mask: nil,
  inputs_embeds: nil,
  start_positions: nil,
  end_positions: nil,
  output_attentions: nil,
  output_hidden_states: nil,
  return_dict: nil
)
  return_dict = !return_dict.nil? ? return_dict : @config.use_return_dict

  distilbert_output = @distilbert.(
    input_ids: input_ids,
    attention_mask: attention_mask,
    head_mask: head_mask,
    inputs_embeds: inputs_embeds,
    output_attentions: output_attentions,
    output_hidden_states: output_hidden_states,
    return_dict: return_dict
  )
  hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

  hidden_states = @dropout.(hidden_states)  # (bs, max_query_len, dim)
  logits = @qa_outputs.(hidden_states)  # (bs, max_query_len, 2)
  start_logits, end_logits = logits.split(1, dim: -1)
  start_logits = start_logits.squeeze(-1).contiguous  # (bs, max_query_len)
  end_logits = end_logits.squeeze(-1).contiguous  # (bs, max_query_len)

  total_loss = nil
  if !start_positions.nil? && !end_positions.nil?
    raise Todo
  end

  if !return_dict
    raise Todo
  end

  QuestionAnsweringModelOutput.new(
    loss: total_loss,
    start_logits: start_logits,
    end_logits: end_logits,
    hidden_states: distilbert_output.hidden_states,
    attentions: distilbert_output.attentions
  )
end
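
forward runs the backbone, applies dropout, projects each token's hidden state to two logits with @qa_outputs, and splits them into per-token start_logits and end_logits; the loss computation (when start_positions and end_positions are given) and the non-return_dict path are not yet implemented and raise Todo. A minimal usage sketch; the Transformers::AutoTokenizer API, the return_tensors: option, and the encoding hash keys are assumptions about the surrounding gem, not shown in this file:

# a minimal sketch, assuming AutoTokenizer and from_pretrained behave as in the Python library
tokenizer = Transformers::AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
model = Transformers::Distilbert::DistilBertForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")

question = "Who created Ruby?"
context = "Ruby was created by Yukihiro Matsumoto in the mid-1990s."
encoding = tokenizer.(question, context, return_tensors: "pt")  # assumed to return input_ids / attention_mask tensors

output = model.forward(
  input_ids: encoding[:input_ids],
  attention_mask: encoding[:attention_mask]
)

# with a batch size of 1, a flat argmax over each logits tensor gives the token index
start_index = output.start_logits.argmax.item
end_index = output.end_logits.argmax.item
# the predicted answer is the token span start_index..end_index of the input sequence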