Class: Informers::Utils::ForcedBOSTokenLogitsProcessor

Inherits:
LogitsProcessor
Defined in:
lib/informers/utils/generation.rb

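Instance Method Summary

#call(input_ids, logits) ⇒ Object
  Forces the BOS token at the first generation step.
#initialize(bos_token_id) ⇒ ForcedBOSTokenLogitsProcessor
  Returns a new instance of ForcedBOSTokenLogitsProcessor.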

Constructor Details

#initialize(bos_token_id) ⇒ ForcedBOSTokenLogitsProcessor

Returns a new instance of ForcedBOSTokenLogitsProcessor. bos_token_id is the ID of the token that #call forces as the first generated token.



# File 'lib/informers/utils/generation.rb', line 269

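def initialize(bos_token_id)
  # Call the parent initializer with no arguments.
  super()
  # ID of the token to force at the first generation step.
  @bos_token_id = bos_token_id
end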
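The constructor calls super() with empty parentheses rather than bare super. In Ruby, bare super forwards the current method's arguments to the parent method, while super() passes none. The standalone sketch below illustrates the difference. BaseProcessor is a hypothetical stand-in for LogitsProcessor, and the sketch assumes (this page does not confirm it) that the parent initializer takes no arguments.

```ruby
# Hypothetical stand-in for the parent class; the real LogitsProcessor
# may differ. Here its initializer takes no arguments.
class BaseProcessor
  def initialize
    @initialized_by_base = true
  end

  def initialized_by_base?
    @initialized_by_base == true
  end
end

# Mirrors the constructor above: super() with empty parentheses calls the
# parent initializer with zero arguments.
class WithExplicitSuper < BaseProcessor
  attr_reader :bos_token_id

  def initialize(bos_token_id)
    super()
    @bos_token_id = bos_token_id
  end
end

# Bare super forwards bos_token_id to the zero-arity parent initializer,
# which raises ArgumentError.
class WithBareSuper < BaseProcessor
  def initialize(bos_token_id)
    super
    @bos_token_id = bos_token_id
  end
end

ok = WithExplicitSuper.new(0)
puts ok.initialized_by_base?  # prints "true"
puts ok.bos_token_id          # prints "0"

begin
  WithBareSuper.new(0)
rescue ArgumentError => e
  puts "ArgumentError: #{e.message}"
end
```

If the parent initializer ever gained required parameters, super() would need to pass them explicitly.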

Instance Method Details

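#call(input_ids, logits) ⇒ Object

Forces the model to emit the BOS token at the first generation step. When input_ids contains a single token, every logit is set to -Infinity except the one at bos_token_id, which is set to 0. On later steps, logits is returned unchanged. In both cases the same logits object is returned, modified in place.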



# File 'lib/informers/utils/generation.rb', line 274

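def call(input_ids, logits)
  # Only intervene at the first generation step, when the sequence
  # holds a single token (typically the decoder start token).
  if input_ids.length == 1
    # Mask every token, then re-enable the BOS token so it is the
    # only candidate with a finite score.
    logits.map! { -Float::INFINITY }
    logits[@bos_token_id] = 0
  end
  # Later steps pass through unchanged; logits is modified in place.
  logits
end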
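To see the effect of #call on a decoding loop, the following standalone sketch applies the same masking rule to a plain Array of logits and then makes a greedy (argmax) choice. It assumes input_ids is a flat Array of token IDs for a single sequence. force_bos and greedy_pick are hypothetical helpers written for illustration, not part of Informers.

```ruby
# Hypothetical helper reproducing the masking rule of #call on plain Arrays.
def force_bos(input_ids, logits, bos_token_id)
  if input_ids.length == 1
    logits.map! { -Float::INFINITY } # mask every token
    logits[bos_token_id] = 0         # leave only the BOS token finite
  end
  logits
end

# Hypothetical greedy choice: index of the largest logit.
def greedy_pick(logits)
  logits.each_with_index.max_by { |score, _i| score }.last
end

bos_token_id = 0
decoder_start = [2] # hypothetical one-token prefix, so length == 1

step1 = force_bos(decoder_start, [1.5, 3.2, -0.7, 0.4], bos_token_id)
p step1              # prints [0, -Infinity, -Infinity, -Infinity]
p greedy_pick(step1) # prints 0

step2 = force_bos(decoder_start + [0], [1.5, 3.2, -0.7, 0.4], bos_token_id)
p greedy_pick(step2) # prints 1
```

On the first step every score except the one at bos_token_id is -Infinity, so greedy selection returns bos_token_id; on the second step the original scores pass through and index 1 wins. Setting the BOS logit to 0 rather than to a large positive value is sufficient, because any finite score beats -Infinity.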