Class: Informers::Utils::NoRepeatNGramLogitsProcessor
- Inherits: LogitsProcessor
  - Ancestry (superclass to subclass): Object → LogitsProcessor → Informers::Utils::NoRepeatNGramLogitsProcessor
- Defined in: lib/informers/utils/generation.rb
Instance Method Summary
- #calc_banned_ngram_tokens(prev_input_ids) ⇒ Object
- #call(input_ids, logits) ⇒ Object
- #get_generated_ngrams(banned_ngrams, prev_input_ids) ⇒ Object
- #get_ngrams(prev_input_ids) ⇒ Object
- #initialize(no_repeat_ngram_size) ⇒ NoRepeatNGramLogitsProcessor (constructor): Returns a new instance of NoRepeatNGramLogitsProcessor.
Constructor Details
#initialize(no_repeat_ngram_size) ⇒ NoRepeatNGramLogitsProcessor
Returns a new instance of NoRepeatNGramLogitsProcessor.
192 193 194 195 |
# File 'lib/informers/utils/generation.rb', line 192
# Creates a processor that forbids generating any n-gram of length
# +no_repeat_ngram_size+ more than once in a sequence.
#
# @param no_repeat_ngram_size [Integer] length n of the n-grams that may not repeat
# @raise [ArgumentError] if +no_repeat_ngram_size+ is not a positive Integer
#   (the n-gram bookkeeping has no meaning for sizes below 1)
def initialize(no_repeat_ngram_size)
  # Validate before touching any state so a bad size fails here, with a clear
  # message, rather than deep inside generation.
  unless no_repeat_ngram_size.is_a?(Integer) && no_repeat_ngram_size.positive?
    raise ArgumentError, "no_repeat_ngram_size must be a positive Integer, got #{no_repeat_ngram_size.inspect}"
  end

  super()
  @no_repeat_ngram_size = no_repeat_ngram_size
end
Instance Method Details
#calc_banned_ngram_tokens(prev_input_ids) ⇒ Object
228 229 230 231 232 233 234 235 236 237 238 |
# File 'lib/informers/utils/generation.rb', line 228
# Lists the token ids that must not be generated next. Appended to
# +prev_input_ids+, each of them would complete an n-gram
# (n = @no_repeat_ngram_size) that already occurs in the sequence.
#
# @param prev_input_ids [Array] token ids generated so far
# @return [Array] banned token ids (duplicates possible); [] while the sequence is
#   still too short for the next token to complete an n-gram
def calc_banned_ngram_tokens(prev_input_ids)
  # Nothing can repeat until the sequence plus one more token spans a full n-gram.
  return [] if prev_input_ids.length + 1 < @no_repeat_ngram_size

  get_generated_ngrams(get_ngrams(prev_input_ids), prev_input_ids)
end
#call(input_ids, logits) ⇒ Object
240 241 242 243 244 245 246 247 |
# File 'lib/informers/utils/generation.rb', line 240
# Masks every token id that would repeat an already-generated n-gram by setting
# its score to negative infinity.
#
# @param input_ids [Array] token ids generated so far; passed unchanged to
#   calc_banned_ngram_tokens as a single flat sequence
# @param logits [Array<Float>] next-token scores indexed by token id
#   (assumed from the indexing below; TODO confirm against callers); modified in place
# @return the same +logits+ object
def call(input_ids, logits)
  calc_banned_ngram_tokens(input_ids).each do |token_id|
    logits[token_id] = -Float::INFINITY
  end
  logits
end
#get_generated_ngrams(banned_ngrams, prev_input_ids) ⇒ Object
222 223 224 225 226 |
# File 'lib/informers/utils/generation.rb', line 222
# Returns the tokens that previously followed the current (n - 1)-token suffix of
# +prev_input_ids+ (n = @no_repeat_ngram_size). These are the tokens that would
# complete an already-generated n-gram.
#
# @param banned_ngrams [Hash{String => Array}] prefix-to-followers map as built by
#   get_ngrams, keyed by the JSON encoding of each (n - 1)-token prefix
# @param prev_input_ids [Array] token ids generated so far; callers are expected to
#   supply at least n - 1 of them (calc_banned_ngram_tokens checks this first)
# @return [Array] the recorded followers, or [] when the suffix has not been seen
def get_generated_ngrams(banned_ngrams, prev_input_ids)
  # The lookup key is the last n - 1 tokens. Array#last says this directly; note that
  # Ruby's two-argument Array#slice takes (start, length), not (start, end).
  # For n == 1 this is [], which encodes to the "[]" key used for unigram prefixes.
  ngram_idx = prev_input_ids.last(@no_repeat_ngram_size - 1)
  banned_ngrams[JSON.generate(ngram_idx)] || []
end
#get_ngrams(prev_input_ids) ⇒ Object
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
# File 'lib/informers/utils/generation.rb', line 197
# Indexes every contiguous n-gram of +prev_input_ids+ (n = @no_repeat_ngram_size)
# by its leading n - 1 tokens.
#
# @param prev_input_ids [Array] token ids generated so far (presumably a flat Array
#   of Integers; TODO confirm against callers)
# @return [Hash{String => Array}] maps the JSON encoding of each (n - 1)-token prefix
#   (e.g. "[5,6]"; "[]" when n == 1) to the tokens that followed it, in order of
#   occurrence with duplicates kept. Empty when fewer than n tokens are available.
# @raise [ArgumentError] from Enumerable#each_cons when the size is not positive
def get_ngrams(prev_input_ids)
  # each_cons yields each window of n consecutive tokens, so no index bookkeeping is needed.
  prev_input_ids.each_cons(@no_repeat_ngram_size).each_with_object({}) do |ngram, generated_ngram|
    *prefix, next_token = ngram
    # JSON string keys keep the format identical to the lookup in get_generated_ngrams.
    (generated_ngram[JSON.generate(prefix)] ||= []) << next_token
  end
end