Class: Informers::Utils::GenerationConfig

Inherits:
Object
  • Object
show all
Defined in:
lib/informers/utils/generation.rb

Instance Method Summary collapse

Constructor Details

#initialize(kwargs) ⇒ GenerationConfig

Returns a new instance of GenerationConfig.



4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# File 'lib/informers/utils/generation.rb', line 4

# Builds a generation configuration from a hash of options, filling in a
# default for every recognised option that was not supplied.
#
# Option names may be given as Strings or Symbols. They are normalised to
# Strings before being read, so `max_length: 50` and `"max_length" => 50`
# are equivalent (mirroring the String coercion done by `#[]`).
#
# For options written as `kwargs[...] || default`, a `nil` or `false` value
# falls back to the default. "use_cache" is the exception: its default of
# `true` applies only when the key is absent.
#
# @param kwargs [Hash] generation options keyed by option name. The optional
#   "generation_kwargs" entry is stored separately in `@generation_kwargs`.
def initialize(kwargs)
  # Normalise Symbol keys so they are not silently ignored. Objects that do
  # not offer #transform_keys (anything exposing only #[] / #fetch) are used
  # unchanged, exactly as before.
  kwargs = kwargs.transform_keys(&:to_s) if kwargs.respond_to?(:transform_keys)

  @config = {}

  # Parameters that control the length of the output
  @config["max_length"] = kwargs["max_length"] || 20
  @config["max_new_tokens"] = kwargs["max_new_tokens"]
  @config["min_length"] = kwargs["min_length"] || 0
  @config["min_new_tokens"] = kwargs["min_new_tokens"]
  @config["early_stopping"] = kwargs["early_stopping"] || false
  @config["max_time"] = kwargs["max_time"]

  # Parameters that control the generation strategy used
  @config["do_sample"] = kwargs["do_sample"] || false
  @config["num_beams"] = kwargs["num_beams"] || 1
  @config["num_beam_groups"] = kwargs["num_beam_groups"] || 1
  @config["penalty_alpha"] = kwargs["penalty_alpha"]
  # fetch (not ||) so that an explicit `false` is kept instead of being
  # replaced by the `true` default.
  @config["use_cache"] = kwargs.fetch("use_cache", true)

  # Parameters for manipulation of the model output logits
  @config["temperature"] = kwargs["temperature"] || 1.0
  @config["top_k"] = kwargs["top_k"] || 50
  @config["top_p"] = kwargs["top_p"] || 1.0
  @config["typical_p"] = kwargs["typical_p"] || 1.0
  @config["epsilon_cutoff"] = kwargs["epsilon_cutoff"] || 0.0
  @config["eta_cutoff"] = kwargs["eta_cutoff"] || 0.0
  @config["diversity_penalty"] = kwargs["diversity_penalty"] || 0.0
  @config["repetition_penalty"] = kwargs["repetition_penalty"] || 1.0
  @config["encoder_repetition_penalty"] = kwargs["encoder_repetition_penalty"] || 1.0
  @config["length_penalty"] = kwargs["length_penalty"] || 1.0
  @config["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"] || 0
  @config["bad_words_ids"] = kwargs["bad_words_ids"]
  @config["force_words_ids"] = kwargs["force_words_ids"]
  @config["renormalize_logits"] = kwargs["renormalize_logits"] || false
  @config["constraints"] = kwargs["constraints"]
  @config["forced_bos_token_id"] = kwargs["forced_bos_token_id"]
  @config["forced_eos_token_id"] = kwargs["forced_eos_token_id"]
  @config["remove_invalid_values"] = kwargs["remove_invalid_values"] || false
  @config["exponential_decay_length_penalty"] = kwargs["exponential_decay_length_penalty"]
  @config["suppress_tokens"] = kwargs["suppress_tokens"]
  @config["begin_suppress_tokens"] = kwargs["begin_suppress_tokens"]
  @config["forced_decoder_ids"] = kwargs["forced_decoder_ids"]

  # Parameters that define the output variables of `generate`
  @config["num_return_sequences"] = kwargs["num_return_sequences"] || 1
  @config["output_attentions"] = kwargs["output_attentions"] || false
  @config["output_hidden_states"] = kwargs["output_hidden_states"] || false
  @config["output_scores"] = kwargs["output_scores"] || false
  @config["return_dict_in_generate"] = kwargs["return_dict_in_generate"] || false

  # Special tokens that can be used at generation time
  @config["pad_token_id"] = kwargs["pad_token_id"]
  @config["bos_token_id"] = kwargs["bos_token_id"]
  @config["eos_token_id"] = kwargs["eos_token_id"]

  # Generation parameters exclusive to encoder-decoder models
  @config["encoder_no_repeat_ngram_size"] = kwargs["encoder_no_repeat_ngram_size"] || 0
  @config["decoder_start_token_id"] = kwargs["decoder_start_token_id"]

  # Wild card: kept outside @config, so it is not reachable through #[]
  @generation_kwargs = kwargs["generation_kwargs"] || {}
end

Instance Method Details

#[](key) ⇒ Object



66
67
68
# File 'lib/informers/utils/generation.rb', line 66

# Looks up a single generation option.
#
# The key is converted with #to_s before the lookup, so Symbols and Strings
# address the same entry.
#
# @param key [#to_s] option name, e.g. `:max_length` or `"max_length"`
# @return [Object, nil] the stored value, or nil when no entry exists
def [](key)
  option_name = key.to_s
  @config[option_name]
end

#merge!(config) ⇒ Object



70
71
72
# File 'lib/informers/utils/generation.rb', line 70

# Merges overrides into this configuration in place.
#
# Keys are converted to Strings before merging. `#[]` always looks options
# up by String key, so a Symbol-keyed override merged verbatim would be
# stored but never read back, leaving the old value in effect.
#
# @param config [Hash, #to_hash, nil] overrides keyed by option name
#   (String or Symbol); nil is treated as "no overrides"
# @return [Hash] the internal String-keyed option hash after the merge
# @raise [TypeError] if config cannot be converted to a Hash
def merge!(config)
  overrides = Hash(config).transform_keys(&:to_s)
  @config.merge!(overrides)
end