Class: Informers::FillMaskPipeline

Inherits:
Pipeline < Object
Defined in:
lib/informers/pipelines.rb

Instance Method Summary

Methods inherited from Pipeline

#initialize

Constructor Details

This class inherits a constructor from Informers::Pipeline

Instance Method Details

#call(texts, top_k: 5) ⇒ Object



281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
# File 'lib/informers/pipelines.rb', line 281

# Fills in the mask token of each input text with the highest-scoring
# vocabulary candidates produced by the model.
#
# @param texts [String, Array<String>] a single text or a batch of texts;
#   handed to the tokenizer unchanged
# @param top_k [Integer] forwarded to Utils.get_top_items to limit how many
#   candidates are kept per text
# @return [Array] when +texts+ is an Array, one list of candidate hashes per
#   input; otherwise only the candidate list for the first (single) input.
#   Each candidate hash has the keys :score, :token, :token_str and :sequence.
# @raise [ArgumentError] if an input's token ids do not contain the
#   tokenizer's mask token id
def call(texts, top_k: 5)
  encoded = @tokenizer.(texts, padding: true, truncation: true)
  model_output = @model.(encoded)

  predictions = encoded[:input_ids].each_with_index.map do |ids, row|
    # Only the position returned by #index is filled in for each input.
    mask_position = ids.index(@tokenizer.mask_token_id)
    raise ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text." if mask_position.nil?

    # Logits for this input, narrowed to the masked position.
    mask_logits = model_output.logits[row][mask_position]
    candidates = Utils.get_top_items(Utils.softmax(mask_logits), top_k)

    candidates.map do |candidate|
      # candidate[0] is the token id, candidate[1] its score.
      token_id = candidate[0]

      # Substitute the candidate into a copy so the original ids stay intact
      # for the remaining candidates.
      filled_ids = ids.dup
      filled_ids[mask_position] = token_id

      {
        score: candidate[1],
        token: token_id,
        token_str: @tokenizer.id_to_token(token_id),
        sequence: @tokenizer.decode(filled_ids, skip_special_tokens: true)
      }
    end
  end

  # Mirror the input shape: batch in, batch out; single text in, single list out.
  return predictions if texts.is_a?(Array)

  predictions[0]
end