Class: Informers::Utils::MinLengthLogitsProcessor

Inherits:
LogitsProcessor
Defined in:
lib/informers/utils/generation.rb

Instance Method Summary

Constructor Details

#initialize(min_length, eos_token_id) ⇒ MinLengthLogitsProcessor

Returns a new instance of MinLengthLogitsProcessor.



251
252
253
254
255
# File 'lib/informers/utils/generation.rb', line 251

# Builds a processor that remembers a minimum length and a list of
# end-of-sequence token ids.
#
# @param min_length [Object] threshold stored as-is; #call compares
#   `input_ids.length` against it (presumably an Integer — confirm with callers)
# @param eos_token_id [Object, Array] one token id or an Array of them; an
#   Array is stored unchanged (same object), anything else is wrapped in a
#   one-element Array
def initialize(min_length, eos_token_id)
  super()
  @min_length = min_length
  # Normalise to an Array so later code can always iterate over the ids.
  @eos_token_id =
    case eos_token_id
    when Array then eos_token_id
    else [eos_token_id]
    end
end

Instance Method Details

#call(input_ids, logits) ⇒ Object



257
258
259
260
261
262
263
264
265
# File 'lib/informers/utils/generation.rb', line 257

# Blocks the end-of-sequence tokens while the sequence is shorter than the
# configured minimum length.
#
# When `input_ids.length` is below `@min_length`, the entry of `logits` at
# every stored EOS token id is overwritten with negative infinity. Otherwise
# `logits` is left untouched.
#
# @param input_ids [#length] the ids seen so far; only its length is read
# @param logits [#[]=] scores addressed by token id; mutated in place
# @return [Object] the same `logits` object that was passed in
def call(input_ids, logits)
  return logits unless input_ids.length < @min_length

  @eos_token_id.each { |token_id| logits[token_id] = -Float::INFINITY }
  logits
end