Class: Transformers::PretrainedConfig
- Inherits: Object
  - Object
    - Transformers::PretrainedConfig
- Extended by:
- ClassAttribute
- Defined in:
- lib/transformers/configuration_utils.rb
Direct Known Subclasses
Bert::BertConfig, DebertaV2::DebertaV2Config, Distilbert::DistilBertConfig, Mpnet::MPNetConfig, Vit::ViTConfig, XlmRoberta::XLMRobertaConfig
Instance Attribute Summary
-
#_commit_hash ⇒ Object
readonly
Returns the value of attribute _commit_hash.
-
#add_cross_attention ⇒ Object
readonly
Returns the value of attribute add_cross_attention.
-
#architectures ⇒ Object
readonly
Returns the value of attribute architectures.
-
#chunk_size_feed_forward ⇒ Object
readonly
Returns the value of attribute chunk_size_feed_forward.
-
#id2label ⇒ Object
readonly
Returns the value of attribute id2label.
-
#is_decoder ⇒ Object
readonly
Returns the value of attribute is_decoder.
-
#is_encoder_decoder ⇒ Object
readonly
Returns the value of attribute is_encoder_decoder.
-
#output_attentions ⇒ Object
readonly
Returns the value of attribute output_attentions.
-
#output_hidden_states ⇒ Object
readonly
Returns the value of attribute output_hidden_states.
-
#pad_token_id ⇒ Object
readonly
Returns the value of attribute pad_token_id.
-
#problem_type ⇒ Object
Returns the value of attribute problem_type.
-
#pruned_heads ⇒ Object
readonly
Returns the value of attribute pruned_heads.
-
#tie_encoder_decoder ⇒ Object
readonly
Returns the value of attribute tie_encoder_decoder.
-
#tie_word_embeddings ⇒ Object
readonly
Returns the value of attribute tie_word_embeddings.
-
#tokenizer_class ⇒ Object
readonly
Returns the value of attribute tokenizer_class.
Class Method Summary
- .from_dict(config_dict, **kwargs) ⇒ Object
- .from_pretrained(pretrained_model_name_or_path, cache_dir: nil, force_download: false, local_files_only: false, token: nil, revision: "main", **kwargs) ⇒ Object
- .get_config_dict(pretrained_model_name_or_path, **kwargs) ⇒ Object
Instance Method Summary
- #_attn_implementation ⇒ Object
- #_dict ⇒ Object
- #getattr(key, default) ⇒ Object
- #hasattr(key) ⇒ Object
-
#initialize(**kwargs) ⇒ PretrainedConfig
constructor
A new instance of PretrainedConfig.
-
#method_missing(m, *args, **kwargs) ⇒ Object
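Resolves aliased attribute names listed in the class's `attribute_map` to their backing instance variables (TODO: support setters).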
- #name_or_path ⇒ Object
- #name_or_path=(value) ⇒ Object
- #num_labels ⇒ Object
- #num_labels=(num_labels) ⇒ Object
-
#respond_to_missing?(m, include_private = true) ⇒ Boolean
Returns true for aliased attribute names listed in the class's `attribute_map` (TODO: support setters).
- #to_dict ⇒ Object
- #to_diff_dict ⇒ Object
- #to_json_string(use_diff: true) ⇒ Object
- #to_s ⇒ Object
- #use_return_dict ⇒ Object
Methods included from ClassAttribute
Constructor Details
#initialize(**kwargs) ⇒ PretrainedConfig
Returns a new instance of PretrainedConfig.
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
# File 'lib/transformers/configuration_utils.rb', line 43
# Builds a configuration from keyword arguments.
#
# Each recognised option is removed from +kwargs+ and stored in its own
# instance variable (using the default given in the block when absent).
# Every remaining key/value pair is stored verbatim as an instance variable
# named after the key.
#
# @param kwargs [Hash{Symbol => Object}] configuration options
# @raise [ArgumentError] if +label2id+ or +id2label+ is given but is not a Hash
def initialize(**kwargs)
  @return_dict = kwargs.delete(:return_dict) { true }
  @output_hidden_states = kwargs.delete(:output_hidden_states) { false }
  @output_attentions = kwargs.delete(:output_attentions) { false }
  @pruned_heads = kwargs.delete(:pruned_heads) { {} }
  @tie_word_embeddings = kwargs.delete(:tie_word_embeddings) { true }
  @chunk_size_feed_forward = kwargs.delete(:chunk_size_feed_forward) { 0 }

  # is_encoder_decoder / is_decoder distinguish the encoder from the decoder
  # in encoder-decoder models.
  @is_encoder_decoder = kwargs.delete(:is_encoder_decoder) { false }
  @is_decoder = kwargs.delete(:is_decoder) { false }
  @cross_attention_hidden_size = kwargs.delete(:cross_attention_hidden_size)
  @add_cross_attention = kwargs.delete(:add_cross_attention) { false }
  @tie_encoder_decoder = kwargs.delete(:tie_encoder_decoder) { false }

  # Fine-tuning task arguments
  @architectures = kwargs.delete(:architectures)
  @finetuning_task = kwargs.delete(:finetuning_task)
  @id2label = kwargs.delete(:id2label)
  @label2id = kwargs.delete(:label2id)
  if !@label2id.nil? && !@label2id.is_a?(Hash)
    raise ArgumentError, "Argument label2id should be a dictionary."
  end
  if !@id2label.nil?
    if !@id2label.is_a?(Hash)
      raise ArgumentError, "Argument id2label should be a dictionary."
    end
    num_labels = kwargs.delete(:num_labels)
    if !num_labels.nil? && id2label.length != num_labels
      # NOTE(review): placeholder; upstream Python only logs a warning here.
      raise Todo
    end
    # Keys are always strings in JSON so convert ids to int here.
    @id2label = @id2label.transform_keys(&:to_i)
  else
    self.num_labels = kwargs.delete(:num_labels) { 2 }
  end

  # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
  @tokenizer_class = kwargs.delete(:tokenizer_class)
  @prefix = kwargs.delete(:prefix)
  @bos_token_id = kwargs.delete(:bos_token_id)
  @pad_token_id = kwargs.delete(:pad_token_id)
  @eos_token_id = kwargs.delete(:eos_token_id)
  @sep_token_id = kwargs.delete(:sep_token_id)

  # regression / multi-label classification
  @problem_type = kwargs.delete(:problem_type)

  # Name or path to the pretrained checkpoint
  @name_or_path = kwargs.delete(:name_or_path).to_s

  # Config hash. Kept in @_commit_hash so that the `_commit_hash` reader sees it
  # and `to_dict` (which removes :_commit_hash) does not serialize it.
  @_commit_hash = kwargs.delete(:_commit_hash)

  # Attention implementation to use, if relevant.
  @attn_implementation_internal = kwargs.delete(:attn_implementation)

  # Drop the transformers version info
  @transformers_version = kwargs.delete(:transformers_version)

  # Deal with gradient checkpointing
  # if kwargs[:gradient_checkpointing] == false
  #   warn(
  #     "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
  #     "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
  #     "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
  #   )
  # end

  # Any option not handled above becomes an instance variable of the same name.
  kwargs.each do |k, v|
    instance_variable_set("@#{k}", v)
  end
end
Dynamic Method Handling
This class handles dynamic methods through the method_missing method
#method_missing(m, *args, **kwargs) ⇒ Object
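Resolves aliased attribute names listed in the class's `attribute_map` to their backing instance variables (TODO: support setters).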
24 25 26 27 28 29 30 |
# File 'lib/transformers/configuration_utils.rb', line 24
# Looks +m+ up in the class-level +attribute_map+. When present, returns the
# instance variable named by the mapped value; otherwise defers to super.
# Positional and keyword arguments are not used.
# TODO support setter
def method_missing(m, *args, **kwargs)
  aliases = self.class.attribute_map
  return super unless aliases.include?(m)

  instance_variable_get("@#{aliases[m]}")
end
Instance Attribute Details
#_commit_hash ⇒ Object (readonly)
Returns the value of attribute _commit_hash.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @_commit_hash (nil when never assigned).
def _commit_hash
  instance_variable_get(:@_commit_hash)
end
#add_cross_attention ⇒ Object (readonly)
Returns the value of attribute add_cross_attention.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @add_cross_attention (#initialize defaults it to false).
def add_cross_attention
  instance_variable_get(:@add_cross_attention)
end
#architectures ⇒ Object (readonly)
Returns the value of attribute architectures.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @architectures (nil unless passed to #initialize).
def architectures
  instance_variable_get(:@architectures)
end
#chunk_size_feed_forward ⇒ Object (readonly)
Returns the value of attribute chunk_size_feed_forward.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @chunk_size_feed_forward (#initialize defaults it to 0).
def chunk_size_feed_forward
  instance_variable_get(:@chunk_size_feed_forward)
end
#id2label ⇒ Object (readonly)
Returns the value of attribute id2label.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @id2label, the id => label mapping built by #initialize.
def id2label
  instance_variable_get(:@id2label)
end
#is_decoder ⇒ Object (readonly)
Returns the value of attribute is_decoder.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @is_decoder (#initialize defaults it to false).
def is_decoder
  instance_variable_get(:@is_decoder)
end
#is_encoder_decoder ⇒ Object (readonly)
Returns the value of attribute is_encoder_decoder.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @is_encoder_decoder (#initialize defaults it to false).
def is_encoder_decoder
  instance_variable_get(:@is_encoder_decoder)
end
#output_attentions ⇒ Object (readonly)
Returns the value of attribute output_attentions.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @output_attentions (#initialize defaults it to false).
def output_attentions
  instance_variable_get(:@output_attentions)
end
#output_hidden_states ⇒ Object (readonly)
Returns the value of attribute output_hidden_states.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @output_hidden_states (#initialize defaults it to false).
def output_hidden_states
  instance_variable_get(:@output_hidden_states)
end
#pad_token_id ⇒ Object (readonly)
Returns the value of attribute pad_token_id.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @pad_token_id (nil unless passed to #initialize).
def pad_token_id
  instance_variable_get(:@pad_token_id)
end
#problem_type ⇒ Object
Returns the value of attribute problem_type.
41 42 43 |
# File 'lib/transformers/configuration_utils.rb', line 41
# Reader for @problem_type (nil unless passed to #initialize or assigned later).
def problem_type
  instance_variable_get(:@problem_type)
end
#pruned_heads ⇒ Object (readonly)
Returns the value of attribute pruned_heads.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @pruned_heads (#initialize defaults it to {}).
def pruned_heads
  instance_variable_get(:@pruned_heads)
end
#tie_encoder_decoder ⇒ Object (readonly)
Returns the value of attribute tie_encoder_decoder.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @tie_encoder_decoder (#initialize defaults it to false).
def tie_encoder_decoder
  instance_variable_get(:@tie_encoder_decoder)
end
#tie_word_embeddings ⇒ Object (readonly)
Returns the value of attribute tie_word_embeddings.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @tie_word_embeddings (#initialize defaults it to true).
# The previous definition lacked a method name (`def @tie_word_embeddings`),
# which is a syntax error.
def tie_word_embeddings
  @tie_word_embeddings
end
#tokenizer_class ⇒ Object (readonly)
Returns the value of attribute tokenizer_class.
37 38 39 |
# File 'lib/transformers/configuration_utils.rb', line 37
# Reader for @tokenizer_class (nil unless passed to #initialize).
def tokenizer_class
  instance_variable_get(:@tokenizer_class)
end
Class Method Details
.from_dict(config_dict, **kwargs) ⇒ Object
230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 |
# File 'lib/transformers/configuration_utils.rb', line 230
# Instantiates a config from +config_dict+ and then applies keyword overrides.
#
# @param config_dict [Hash{Symbol => Object}] attributes splatted into +new+
# @param kwargs [Hash] overrides. A key is applied only when the config has a
#   matching public setter (e.g. +problem_type=+). +return_unused_kwargs:+ is
#   consumed here and controls the return shape.
# @return [Object, Array] the config, or +[config, leftover]+ when
#   +return_unused_kwargs+ is truthy. +leftover+ keeps only +:torch_dtype+.
def from_dict(config_dict, **kwargs)
  return_unused_kwargs = kwargs.delete(:return_unused_kwargs) { false }

  # A commit hash already recorded in config_dict wins over one passed in kwargs.
  if kwargs.include?(:_commit_hash) && config_dict.include?(:_commit_hash)
    kwargs[:_commit_hash] = config_dict[:_commit_hash]
  end

  config = new(**config_dict)

  kwargs.each do |key, value|
    setter = "#{key}="
    config.public_send(setter, value) if config.respond_to?(setter)
  end
  # Every key except :torch_dtype is discarded, whether or not a setter used it.
  # NOTE(review): upstream Python only drops keys it actually applied.
  kwargs.select! { |key, _| key == :torch_dtype }

  Transformers.logger.info("Model config #{config}")
  return_unused_kwargs ? [config, kwargs] : config
end
.from_pretrained(pretrained_model_name_or_path, cache_dir: nil, force_download: false, local_files_only: false, token: nil, revision: "main", **kwargs) ⇒ Object
216 217 218 219 220 221 222 223 224 225 226 227 228 |
# File 'lib/transformers/configuration_utils.rb', line 216
# Loads the config dict for +pretrained_model_name_or_path+ with
# +get_config_dict+ and builds a config from it with +from_dict+.
#
# The download/caching options are forwarded to +get_config_dict+ along with
# the remaining kwargs. Before this change they were accepted and then silently
# ignored, so e.g. a custom +revision+ or +cache_dir+ had no effect.
# NOTE(review): assumes _get_config_dict reads these keys from its kwargs — confirm.
#
# @param pretrained_model_name_or_path [String] model id or local path
# @return [Object] whatever +from_dict+ returns for the resolved dict
def from_pretrained(
  pretrained_model_name_or_path,
  cache_dir: nil,
  force_download: false,
  local_files_only: false,
  token: nil,
  revision: "main",
  **kwargs
)
  config_dict, kwargs = get_config_dict(
    pretrained_model_name_or_path,
    cache_dir: cache_dir,
    force_download: force_download,
    local_files_only: local_files_only,
    token: token,
    revision: revision,
    **kwargs
  )

  from_dict(config_dict, **kwargs)
end
.get_config_dict(pretrained_model_name_or_path, **kwargs) ⇒ Object
261 262 263 264 265 266 |
# File 'lib/transformers/configuration_utils.rb', line 261
# Resolves the config dict of the base config file via +_get_config_dict+.
# The helper's result is narrowed to a two-element array:
# +[config_dict, remaining_kwargs]+.
def get_config_dict(pretrained_model_name_or_path, **kwargs)
  resolved_dict, remaining_kwargs = _get_config_dict(pretrained_model_name_or_path, **kwargs)
  [resolved_dict, remaining_kwargs]
end
Instance Method Details
#_attn_implementation ⇒ Object
135 136 137 138 139 140 141 142 143 144 145 146 147 |
# File 'lib/transformers/configuration_utils.rb', line 135
# Attention implementation configured for this model. Kept private-ish for
# now because it cannot be changed after construction (a
# PreTrainedModel.use_attn_implementation method still needs implementing).
# Returns "eager" whenever nothing was configured (ivar missing or nil), for
# backward compatibility.
def _attn_implementation
  configured =
    if instance_variable_defined?(:@attn_implementation_internal)
      @attn_implementation_internal
    end
  configured.nil? ? "eager" : configured
end
#_dict ⇒ Object
175 176 177 |
# File 'lib/transformers/configuration_utils.rb', line 175
# Snapshot of every instance variable as a Hash keyed by the ivar name
# without its leading "@" (as a Symbol), in definition order.
def _dict
  instance_variables.each_with_object({}) do |ivar, attrs|
    attrs[ivar.to_s.delete_prefix("@").to_sym] = instance_variable_get(ivar)
  end
end
#getattr(key, default) ⇒ Object
201 202 203 204 205 206 207 208 209 |
# File 'lib/transformers/configuration_utils.rb', line 201
# Python-style getattr: a public method named +key+ takes precedence.
# Otherwise the instance variable @key is returned if it is defined.
# Otherwise +default+ is returned.
def getattr(key, default)
  return public_send(key) if respond_to?(key)

  ivar_name = "@#{key}"
  instance_variable_defined?(ivar_name) ? instance_variable_get(ivar_name) : default
end
#hasattr(key) ⇒ Object
211 212 213 |
# File 'lib/transformers/configuration_utils.rb', line 211
# Python-style hasattr: true when a public method named +key+ exists or
# the instance variable @key is defined.
def hasattr(key)
  return true if respond_to?(key)

  instance_variable_defined?("@#{key}")
end
#name_or_path ⇒ Object
116 117 118 |
# File 'lib/transformers/configuration_utils.rb', line 116
# Name or path of the pretrained checkpoint (a String once set by
# #initialize or #name_or_path=).
def name_or_path
  instance_variable_get(:@name_or_path)
end
#name_or_path=(value) ⇒ Object
120 121 122 |
# File 'lib/transformers/configuration_utils.rb', line 120
# Stores +value+ as the checkpoint name/path, coerced with #to_s
# (so nil becomes "").
def name_or_path=(value)
  instance_variable_set(:@name_or_path, value.to_s)
end
#num_labels ⇒ Object
124 125 126 |
# File 'lib/transformers/configuration_utils.rb', line 124
# Number of labels, derived from the size of the id => label mapping.
def num_labels
  instance_variable_get(:@id2label).length
end
#num_labels=(num_labels) ⇒ Object
128 129 130 131 132 133 |
# File 'lib/transformers/configuration_utils.rb', line 128
# Rebuilds the label maps with generic "LABEL_<i>" names for ids
# 0...num_labels. An existing @id2label whose size already equals
# +num_labels+ is left untouched (this also leaves @label2id unchanged).
def num_labels=(num_labels)
  return if !@id2label.nil? && @id2label.length == num_labels

  @id2label = num_labels.times.to_h { |idx| [idx, "LABEL_#{idx}"] }
  @label2id = @id2label.to_h { |idx, label| [label, idx] }
end
#respond_to_missing?(m, include_private = true) ⇒ Boolean
Returns true for aliased attribute names listed in the class's `attribute_map` (TODO: support setters).
33 34 35 |
# File 'lib/transformers/configuration_utils.rb', line 33
# Companion to #method_missing. Names present in the class-level
# +attribute_map+ are reported as supported; everything else is deferred to super.
# TODO support setter
def respond_to_missing?(m, include_private = true)
  return true if self.class.attribute_map.include?(m)

  super
end
#to_dict ⇒ Object
179 180 181 182 183 184 185 186 187 188 189 190 |
# File 'lib/transformers/configuration_utils.rb', line 179
# Serializes the config into a Hash.
#
# Takes a deep copy of every instance variable (via +_dict+), then adds the
# class's +model_type+ and the library +VERSION+. Internal bookkeeping is
# removed: the auto class, the commit hash and the attention implementation.
# +_dict+ derives keys from ivar names, and these ivars may be spelled with or
# without a leading underscore (e.g. @commit_hash vs @_commit_hash). Both
# spellings are therefore deleted so the values never leak into saved configs.
#
# @return [Hash{Symbol => Object}]
def to_dict
  output = Copy.deepcopy(_dict)
  output[:model_type] = self.class.model_type

  %i[
    _auto_class
    _commit_hash commit_hash
    _attn_implementation_internal attn_implementation_internal
  ].each { |key| output.delete(key) }

  # Transformers version when serializing the model
  output[:transformers_version] = VERSION
  output
end
#to_diff_dict ⇒ Object
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
# File 'lib/transformers/configuration_utils.rb', line 157
# Serializes only the attributes that differ from a bare PretrainedConfig.
# :transformers_version is always kept, and :name_or_path is emitted under
# the key :_name_or_path.
def to_diff_dict
  current = to_dict
  # Serialized form of a default PretrainedConfig: the baseline for the diff.
  baseline = PretrainedConfig.new.to_dict

  current.each_with_object({}) do |(attr, value), diff|
    out_key = attr == :name_or_path ? :_name_or_path : attr
    differs = !baseline.include?(out_key) || value != baseline[out_key]
    diff[out_key] = value if differs || out_key == :transformers_version
  end
end
#to_json_string(use_diff: true) ⇒ Object
192 193 194 195 196 197 198 199 |
# File 'lib/transformers/configuration_utils.rb', line 192
# Pretty-printed JSON (keys sorted, trailing newline) of either the diff
# against defaults or the full dict.
# Only the literal +true+ selects the diff; any other value serializes everything.
def to_json_string(use_diff: true)
  config_dict = use_diff == true ? to_diff_dict : to_dict
  sorted = config_dict.sort_by(&:first).to_h
  "#{JSON.pretty_generate(sorted)}\n"
end
#to_s ⇒ Object
153 154 155 |
# File 'lib/transformers/configuration_utils.rb', line 153
# Human-readable form: the class name followed by the JSON diff string.
def to_s
  format("%s %s", self.class.name, to_json_string)
end
#use_return_dict ⇒ Object
149 150 151 |
# File 'lib/transformers/configuration_utils.rb', line 149
# Value of @return_dict (#initialize defaults it to true).
def use_return_dict
  instance_variable_get(:@return_dict)
end