Class: Informers::FillMaskPipeline
- Defined in:
- lib/informers/pipelines.rb
Instance Method Summary collapse
Methods inherited from Pipeline
Constructor Details
This class inherits a constructor from Informers::Pipeline
Instance Method Details
#call(texts, top_k: 5) ⇒ Object
281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 |
# File 'lib/informers/pipelines.rb', line 281

# Predicts replacements for the first mask token found in each input text.
#
# All tokenized texts are checked for a mask token before the model is
# invoked, so invalid input raises without paying for a forward pass.
#
# @param texts [String, Array<String>] one text or a batch of texts; passed
#   to the tokenizer with padding and truncation enabled
# @param top_k [Integer] candidate count forwarded to Utils.get_top_items
# @return [Array<Hash>, Array<Array<Hash>>] one list of candidates per text
#   when +texts+ is an Array, otherwise the list for the single text. Each
#   candidate has :score, :token, :token_str and :sequence keys.
# @raise [ArgumentError] if a tokenized text contains no mask token
def call(texts, top_k: 5)
  model_inputs = @tokenizer.(texts, padding: true, truncation: true)
  input_ids = model_inputs[:input_ids]

  # Validate up front: a missing mask token should fail before inference runs.
  mask_token_id = @tokenizer.mask_token_id
  mask_positions = input_ids.map do |ids|
    ids.index(mask_token_id) ||
      raise(ArgumentError, "Mask token (#{@tokenizer.mask_token}) not found in text.")
  end

  outputs = @model.(model_inputs)

  results = input_ids.each_with_index.map do |ids, i|
    mask_position = mask_positions[i]
    probabilities = Utils.softmax(outputs.logits[i][mask_position])

    Utils.get_top_items(probabilities, top_k).map do |item|
      token_id = item[0]

      # Substitute the candidate for the mask in a copy so +ids+ stays intact
      # for the remaining candidates.
      sequence = ids.dup
      sequence[mask_position] = token_id

      {
        score: item[1],
        token: token_id,
        token_str: @tokenizer.id_to_token(token_id),
        sequence: @tokenizer.decode(sequence, skip_special_tokens: true)
      }
    end
  end

  # Mirror the input shape: batch in, batch out; single text in, single list out.
  texts.is_a?(Array) ? results : results.first
end