Class: Informers::Utils::NoRepeatNGramLogitsProcessor

Inherits:
LogitsProcessor
Defined in:
lib/informers/utils/generation.rb

Instance Method Summary

Constructor Details

#initialize(no_repeat_ngram_size) ⇒ NoRepeatNGramLogitsProcessor

Returns a new instance of NoRepeatNGramLogitsProcessor.



192
193
194
195
# File 'lib/informers/utils/generation.rb', line 192

# Creates the processor and remembers the n-gram length it enforces.
#
# The parent initializer is invoked with no arguments before the size is
# stored, so the value is available to every other method of this processor.
#
# @param no_repeat_ngram_size [Integer] length of the n-grams that may occur
#   only once (presumably a positive integer -- TODO confirm at call sites)
def initialize(no_repeat_ngram_size)
  super()
  @no_repeat_ngram_size = no_repeat_ngram_size
end

Instance Method Details

#calc_banned_ngram_tokens(prev_input_ids) ⇒ Object



228
229
230
231
232
233
234
235
236
237
238
# File 'lib/informers/utils/generation.rb', line 228

# Returns the token ids that must not be generated next, because each one
# would complete an n-gram (n = @no_repeat_ngram_size) that already occurs
# in prev_input_ids.
#
# @param prev_input_ids [Array] token ids generated so far
# @return [Array] banned token ids (one entry per earlier occurrence, so
#   duplicates are possible); [] when nothing is banned
def calc_banned_ngram_tokens(prev_input_ids)
  # A nil or non-positive size disables the constraint. Without this guard a
  # size of 0 reaches get_ngrams, which keys on nil slices and returns nil
  # entries that are not valid logits indices (and nil raises on `<` below).
  return [] unless @no_repeat_ngram_size&.positive?
  # Fewer than n - 1 tokens so far: no complete (n - 1)-token prefix to match yet.
  return [] if prev_input_ids.length + 1 < @no_repeat_ngram_size

  get_generated_ngrams(get_ngrams(prev_input_ids), prev_input_ids)
end

#call(input_ids, logits) ⇒ Object



240
241
242
243
244
245
246
247
# File 'lib/informers/utils/generation.rb', line 240

# Masks every token that would repeat an n-gram already present in input_ids.
#
# Each id returned by calc_banned_ngram_tokens is written into logits as
# -Float::INFINITY. The logits object is mutated in place and returned.
#
# @param input_ids [Array] token ids generated so far
# @param logits scores indexed by token id (assumed to be one row of
#   vocabulary logits -- TODO confirm container type at call sites)
# @return the same logits object that was passed in
def call(input_ids, logits)
  calc_banned_ngram_tokens(input_ids).each_with_object(logits) do |token_id, scores|
    scores[token_id] = -Float::INFINITY
  end
end

#get_generated_ngrams(banned_ngrams, prev_input_ids) ⇒ Object



222
223
224
225
226
# File 'lib/informers/utils/generation.rb', line 222

# Looks up which tokens previously followed the current tail of the sequence.
#
# The tail begins at index (length + 1 - @no_repeat_ngram_size). When
# prev_input_ids holds at least n - 1 ids, this is its last (n - 1) ids. The
# tail is serialized with JSON.generate, the same key format get_ngrams uses.
#
# @param banned_ngrams [Hash{String => Array}] JSON prefix key => following tokens
# @param prev_input_ids [Array] token ids generated so far
# @return [Array] tokens that would repeat an n-gram; [] if the tail is unseen
def get_generated_ngrams(banned_ngrams, prev_input_ids)
  seq_len = prev_input_ids.length
  tail_start = seq_len + 1 - @no_repeat_ngram_size
  # The second argument of Array#slice is a length, not an end index; passing
  # seq_len overshoots and Ruby truncates at the array end, giving tail_start..end.
  tail_key = JSON.generate(prev_input_ids.slice(tail_start, seq_len))
  banned_ngrams[tail_key] || []
end

#get_ngrams(prev_input_ids) ⇒ Object



197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# File 'lib/informers/utils/generation.rb', line 197

# Builds a table mapping every (n - 1)-token prefix in prev_input_ids to the
# tokens that followed it, where n is @no_repeat_ngram_size.
#
# Every contiguous window of n ids is visited in order. Its first n - 1 ids,
# serialized with JSON.generate (e.g. [1, 2] -> "[1,2]"), form the key. Its
# last id is appended to that key's list, once per occurrence, in order of
# appearance.
#
# @param prev_input_ids [Array] token ids generated so far
# @return [Hash{String => Array}] prefix key => following tokens
def get_ngrams(prev_input_ids)
  size = @no_repeat_ngram_size
  window_count = prev_input_ids.length + 1 - size

  windows = (0...window_count).map do |start|
    size.times.map { |offset| prev_input_ids[start + offset] }
  end

  windows.each_with_object({}) do |window, table|
    last_index = window.length - 1
    prefix_key = JSON.generate(window[0, last_index])
    (table[prefix_key] ||= []) << window[last_index]
  end
end