Class: Informers::Utils::BeamSearchSampler

Inherits:
Sampler
  • Object
show all
Defined in:
lib/informers/utils/generation.rb

Instance Method Summary collapse

Methods inherited from Sampler

#call, #get_logits, get_sampler, #initialize

Constructor Details

This class inherits a constructor from Informers::Utils::Sampler

Instance Method Details

#sample(logits, index = -1) ⇒ Object



134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# File 'lib/informers/utils/generation.rb', line 134

# Selects the highest-scoring candidates for one beam-search step.
#
# The candidate pool is the top `k` items of the selected logits, where `k`
# is the size of the last dimension of `logits`, capped by the "top_k"
# generation setting when that setting is a positive number. Softmax is
# computed over the pool only, so the returned log-probabilities are
# relative to the pool rather than to the full set of logits.
#
# @param logits [Object] logits container understood by `Utils.dims` and
#   `get_logits`
# @param index [Integer] position forwarded to `get_logits` (default -1)
# @return [Array<Array>] at most "num_beams" pairs of
#   `[first element of a top item, natural log of its softmax probability]`,
#   in the order produced by `Utils.get_top_items`; fewer pairs are returned
#   when the pool holds fewer than "num_beams" candidates
def sample(logits, index = -1)
  k = Utils.dims(logits)[-1] # defaults to vocab size

  # An unset "top_k" means "no cap"; comparing nil with 0 would raise.
  top_k = @generation_config["top_k"]
  k = [top_k, k].min if top_k && top_k > 0

  # Get logits of nth token
  logs = get_logits(logits, index)

  # Get top k tokens
  top_logits = Utils.get_top_items(logs, k)

  # Compute softmax over logits
  scores = top_logits.map { |x| x[1] }
  probabilities = Utils.softmax(scores)

  # Never index past the candidate pool: with a small "top_k" (or a small
  # vocabulary) there can be fewer candidates than requested beams.
  # A nil "num_beams" keeps yielding an empty result, as Array.new(nil) did.
  beam_count = [@generation_config["num_beams"] || 0, scores.length].min

  Array.new(beam_count) do |i|
    [
      top_logits[i][0],
      Math.log(probabilities[i])
    ]
  end
end