Class: Transformers::PretrainedConfig

Inherits:
Object
  • Object
show all
Extended by:
ClassAttribute
Defined in:
lib/transformers/configuration_utils.rb

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Methods included from ClassAttribute

class_attribute

Constructor Details

#initialize(**kwargs) ⇒ PretrainedConfig

Returns a new instance of PretrainedConfig.



43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# File 'lib/transformers/configuration_utils.rb', line 43

# Builds a configuration from keyword options.
#
# Recognised options are removed from +kwargs+ and stored in dedicated
# instance variables, using the defaults given in the blocks below. Every
# option that is still left afterwards is stored verbatim in an instance
# variable of the same name.
#
# @param kwargs [Hash] configuration options
# @raise [ArgumentError] if +label2id+ or +id2label+ is given but is not a Hash
# @raise [Todo] if both +id2label+ and +num_labels+ are given and disagree
def initialize(**kwargs)
  @return_dict = kwargs.delete(:return_dict) { true }
  @output_hidden_states = kwargs.delete(:output_hidden_states) { false }
  @output_attentions = kwargs.delete(:output_attentions) { false }
  @pruned_heads = kwargs.delete(:pruned_heads) { {} }
  @tie_word_embeddings = kwargs.delete(:tie_word_embeddings) { true }
  @chunk_size_feed_forward = kwargs.delete(:chunk_size_feed_forward) { 0 }

  # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
  @is_encoder_decoder = kwargs.delete(:is_encoder_decoder) { false }
  @is_decoder = kwargs.delete(:is_decoder) { false }
  @cross_attention_hidden_size = kwargs.delete(:cross_attention_hidden_size)
  @add_cross_attention = kwargs.delete(:add_cross_attention) { false }
  @tie_encoder_decoder = kwargs.delete(:tie_encoder_decoder) { false }

  # Fine-tuning task arguments
  @architectures = kwargs.delete(:architectures)
  @finetuning_task = kwargs.delete(:finetuning_task)
  @id2label = kwargs.delete(:id2label)
  @label2id = kwargs.delete(:label2id)
  if !@label2id.nil? && !@label2id.is_a?(Hash)
    raise ArgumentError, "Argument label2id should be a dictionary."
  end
  if @id2label.nil?
    # No explicit label names: generate them (two labels unless num_labels says otherwise).
    self.num_labels = kwargs.delete(:num_labels) { 2 }
  else
    raise ArgumentError, "Argument id2label should be a dictionary." unless @id2label.is_a?(Hash)

    num_labels = kwargs.delete(:num_labels)
    raise Todo if !num_labels.nil? && @id2label.length != num_labels

    # Keys are always strings in JSON so convert ids to int here. Going through
    # to_s first also accepts Symbol keys, which do not respond to to_i.
    @id2label = @id2label.transform_keys { |id| id.to_s.to_i }
  end

  # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
  @tokenizer_class = kwargs.delete(:tokenizer_class)
  @prefix = kwargs.delete(:prefix)
  @bos_token_id = kwargs.delete(:bos_token_id)
  @pad_token_id = kwargs.delete(:pad_token_id)
  @eos_token_id = kwargs.delete(:eos_token_id)
  @sep_token_id = kwargs.delete(:sep_token_id)

  # regression / multi-label classification
  @problem_type = kwargs.delete(:problem_type)

  # Name or path to the pretrained checkpoint (nil becomes "")
  @name_or_path = kwargs.delete(:name_or_path).to_s
  # Commit hash; stored in the instance variable that the `_commit_hash` reader returns
  @_commit_hash = kwargs.delete(:_commit_hash)

  # Attention implementation to use, if relevant.
  @attn_implementation_internal = kwargs.delete(:attn_implementation)

  # Drop the transformers version info
  @transformers_version = kwargs.delete(:transformers_version)

  # Deal with gradient checkpointing
  # if kwargs[:gradient_checkpointing] == false
  #   warn(
  #     "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 " +
  #     "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the " +
  #     "`Trainer` API, pass `gradient_checkpointing: true` in your `TrainingArguments`."
  #   )
  # end

  # Any option not consumed above becomes an instance variable of the same name.
  kwargs.each do |name, value|
    instance_variable_set("@#{name}", value)
  end
end

Dynamic Method Handling

This class handles dynamic methods through the method_missing method

#method_missing(m, *args, **kwargs) ⇒ Object

TODO support setter



24
25
26
27
28
29
30
# File 'lib/transformers/configuration_utils.rb', line 24

# Resolves calls to aliased attribute names.
#
# When +m+ is a key of the class-level +attribute_map+, the instance variable
# named by the mapped value is returned; any other call is passed to +super+.
# Positional and keyword arguments are accepted but not used for aliases.
#
# @param m [Symbol] the name of the missing method
# @return [Object, nil] the value of the mapped instance variable
def method_missing(m, *args, **kwargs)
  aliases = self.class.attribute_map
  return super unless aliases.include?(m)

  instance_variable_get("@#{aliases[m]}")
end

Instance Attribute Details

#_commit_hash ⇒ Object (readonly)

Returns the value of attribute _commit_hash.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@_commit_hash+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def _commit_hash = @_commit_hash

#add_cross_attention ⇒ Object (readonly)

Returns the value of attribute add_cross_attention.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@add_cross_attention+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def add_cross_attention = @add_cross_attention

#architectures ⇒ Object (readonly)

Returns the value of attribute architectures.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@architectures+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def architectures = @architectures

#chunk_size_feed_forward ⇒ Object (readonly)

Returns the value of attribute chunk_size_feed_forward.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@chunk_size_feed_forward+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def chunk_size_feed_forward = @chunk_size_feed_forward

#id2label ⇒ Object (readonly)

Returns the value of attribute id2label.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@id2label+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def id2label = @id2label

#is_decoder ⇒ Object (readonly)

Returns the value of attribute is_decoder.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@is_decoder+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def is_decoder = @is_decoder

#is_encoder_decoder ⇒ Object (readonly)

Returns the value of attribute is_encoder_decoder.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@is_encoder_decoder+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def is_encoder_decoder = @is_encoder_decoder

#output_attentions ⇒ Object (readonly)

Returns the value of attribute output_attentions.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@output_attentions+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def output_attentions = @output_attentions

#output_hidden_states ⇒ Object (readonly)

Returns the value of attribute output_hidden_states.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@output_hidden_states+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def output_hidden_states = @output_hidden_states

#pad_token_id ⇒ Object (readonly)

Returns the value of attribute pad_token_id.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@pad_token_id+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def pad_token_id = @pad_token_id

#problem_type ⇒ Object

Returns the value of attribute problem_type.



41
42
43
# File 'lib/transformers/configuration_utils.rb', line 41

# Reader for the +@problem_type+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def problem_type = @problem_type

#pruned_heads ⇒ Object (readonly)

Returns the value of attribute pruned_heads.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@pruned_heads+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def pruned_heads = @pruned_heads

#tie_encoder_decoder ⇒ Object (readonly)

Returns the value of attribute tie_encoder_decoder.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@tie_encoder_decoder+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def tie_encoder_decoder = @tie_encoder_decoder

#tie_word_embeddings ⇒ Object (readonly)

Returns the value of attribute tie_word_embeddings.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@tie_word_embeddings+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def tie_word_embeddings = @tie_word_embeddings

#tokenizer_class ⇒ Object (readonly)

Returns the value of attribute tokenizer_class.



37
38
39
# File 'lib/transformers/configuration_utils.rb', line 37

# Reader for the +@tokenizer_class+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def tokenizer_class = @tokenizer_class

Class Method Details

.from_dict(config_dict, **kwargs) ⇒ Object



230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
# File 'lib/transformers/configuration_utils.rb', line 230

# Instantiates a configuration from a hash of options.
#
# The instance is built with +new(**config_dict)+. Each remaining keyword is
# then applied through the matching public writer when the instance has one.
# Afterwards every keyword except +:torch_dtype+ is discarded, whether or not
# it was applied.
#
# @param config_dict [Hash] options passed to the constructor
# @param kwargs [Hash] overrides; +:return_unused_kwargs+ selects the return shape
# @return [Object, Array] the config, or +[config, kwargs]+ when
#   +return_unused_kwargs+ is truthy
def from_dict(config_dict, **kwargs)
  want_unused = kwargs.delete(:return_unused_kwargs) { false }

  # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
  if kwargs.include?(:_commit_hash) && config_dict.include?(:_commit_hash)
    kwargs[:_commit_hash] = config_dict[:_commit_hash]
  end

  config = new(**config_dict)

  # Apply each override for which the config exposes a writer.
  kwargs.each do |key, value|
    writer = "#{key}="
    config.public_send(writer, value) if config.respond_to?(writer)
  end
  # Only :torch_dtype survives as an "unused" keyword.
  kwargs.keep_if { |key, _| key == :torch_dtype }

  Transformers.logger.info("Model config #{config}")
  want_unused ? [config, kwargs] : config
end

.from_pretrained(pretrained_model_name_or_path, cache_dir: nil, force_download: false, local_files_only: false, token: nil, revision: "main", **kwargs) ⇒ Object



216
217
218
219
220
221
222
223
224
225
226
227
228
# File 'lib/transformers/configuration_utils.rb', line 216

# Loads the configuration dictionary for a model and instantiates a config from it.
#
# The download-related options are forwarded to +get_config_dict+ together
# with any extra keywords, so that they are not silently ignored. What is
# returned as the second element of +get_config_dict+ is passed on to
# +from_dict+ as keyword overrides.
#
# @param pretrained_model_name_or_path [Object] passed through to +get_config_dict+
# @param cache_dir [Object, nil] forwarded to +get_config_dict+
# @param force_download [Boolean] forwarded to +get_config_dict+
# @param local_files_only [Boolean] forwarded to +get_config_dict+
# @param token [Object, nil] forwarded to +get_config_dict+
# @param revision [String] forwarded to +get_config_dict+
# @param kwargs [Hash] additional options, also forwarded
# @return [Object] whatever +from_dict+ returns
def from_pretrained(
  pretrained_model_name_or_path,
  cache_dir: nil,
  force_download: false,
  local_files_only: false,
  token: nil,
  revision: "main",
  **kwargs
)
  # Named keywords are bound separately from **kwargs, so they must be put
  # back explicitly for the loader to see them.
  loader_options = kwargs.merge(
    cache_dir: cache_dir,
    force_download: force_download,
    local_files_only: local_files_only,
    token: token,
    revision: revision
  )
  config_dict, remaining = get_config_dict(pretrained_model_name_or_path, **loader_options)

  from_dict(config_dict, **remaining)
end

.get_config_dict(pretrained_model_name_or_path, **kwargs) ⇒ Object



261
262
263
264
265
266
# File 'lib/transformers/configuration_utils.rb', line 261

# Fetches the configuration dictionary for a model by delegating to
# +_get_config_dict+.
#
# @param pretrained_model_name_or_path [Object] passed through unchanged
# @param kwargs [Hash] passed through unchanged
# @return [Array] the first two values returned by +_get_config_dict+
def get_config_dict(pretrained_model_name_or_path, **kwargs)
  # Get config dict associated with the base config file
  base_config, remaining = _get_config_dict(pretrained_model_name_or_path, **kwargs)

  [base_config, remaining]
end

Instance Method Details

#_attn_implementation ⇒ Object



135
136
137
138
139
140
141
142
143
144
145
146
147
# File 'lib/transformers/configuration_utils.rb', line 135

# Returns the attention implementation stored in
# +@attn_implementation_internal+, falling back to "eager" when that variable
# is undefined or nil.
#
# @return [Object] the stored value, or "eager"
def _attn_implementation
  # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
  # An undefined instance variable reads as nil, so a single nil check covers both cases.
  stored = instance_variable_get(:@attn_implementation_internal)
  stored.nil? ? "eager" : stored
end

#_dict ⇒ Object



175
176
177
# File 'lib/transformers/configuration_utils.rb', line 175

# Collects all instance variables into a hash keyed by their names without
# the leading "@".
#
# @return [Hash{Symbol => Object}]
def _dict
  instance_variables.each_with_object({}) do |ivar, attrs|
    attrs[ivar.to_s.delete_prefix("@").to_sym] = instance_variable_get(ivar)
  end
end

#getattr(key, default) ⇒ Object



201
202
203
204
205
206
207
208
209
# File 'lib/transformers/configuration_utils.rb', line 201

# Looks up an attribute by name.
#
# A public method named +key+ takes precedence; otherwise the instance
# variable +@key+ is used when it is defined; otherwise +default+ is returned.
#
# @param key [Symbol, String] attribute name
# @param default [Object] value returned when nothing matches
# @return [Object]
def getattr(key, default)
  return public_send(key) if respond_to?(key)

  ivar = "@#{key}"
  instance_variable_defined?(ivar) ? instance_variable_get(ivar) : default
end

#hasattr(key) ⇒ Object



211
212
213
# File 'lib/transformers/configuration_utils.rb', line 211

# Tells whether an attribute named +key+ exists, either as a public method
# or as a defined instance variable +@key+.
#
# @param key [Symbol, String] attribute name
# @return [Boolean]
def hasattr(key)
  return true if respond_to?(key)

  instance_variable_defined?("@#{key}")
end

#name_or_path ⇒ Object



116
117
118
# File 'lib/transformers/configuration_utils.rb', line 116

# Reader for the +@name_or_path+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def name_or_path = @name_or_path

#name_or_path=(value) ⇒ Object



120
121
122
# File 'lib/transformers/configuration_utils.rb', line 120

# Stores the model name or path, converted with +to_s+ (so nil becomes "").
#
# @param value [Object] the new name or path
# @return [String] the stored string
def name_or_path=(value)
  as_string = value.to_s
  @name_or_path = as_string
end

#num_labels ⇒ Object



124
125
126
# File 'lib/transformers/configuration_utils.rb', line 124

# Number of labels, derived from the size of +@id2label+.
#
# @return [Integer] the length of +@id2label+
def num_labels = @id2label.length

#num_labels=(num_labels) ⇒ Object



128
129
130
131
132
133
# File 'lib/transformers/configuration_utils.rb', line 128

# Regenerates the label maps when the requested label count differs from the
# current one.
#
# When +@id2label+ is nil or its length is not +num_labels+, +@id2label+ is
# replaced with { 0 => "LABEL_0", ... } and +@label2id+ with its inverse.
# Existing maps of the right size are left untouched.
#
# @param num_labels [Integer] desired number of labels
def num_labels=(num_labels)
  return unless @id2label.nil? || @id2label.length != num_labels

  @id2label = num_labels.times.each_with_object({}) do |index, labels|
    labels[index] = "LABEL_#{index}"
  end
  @label2id = @id2label.invert
end

#respond_to_missing?(m, include_private = true) ⇒ Boolean

TODO support setter

Returns:

  • (Boolean)


33
34
35
# File 'lib/transformers/configuration_utils.rb', line 33

# Reports aliased attribute names (keys of the class-level +attribute_map+)
# as supported, and defers to +super+ for every other name.
#
# @param m [Symbol] method name being queried
# @param include_private [Boolean] passed on to +super+
# @return [Boolean]
def respond_to_missing?(m, include_private = true)
  return true if self.class.attribute_map.include?(m)

  super
end

#to_dict ⇒ Object



179
180
181
182
183
184
185
186
187
188
189
190
# File 'lib/transformers/configuration_utils.rb', line 179

# Serializes the configuration to a hash.
#
# Starts from a deep copy of +_dict+ (all instance variables keyed by name
# without the "@"), adds the class-level +model_type+, removes internal
# bookkeeping entries and stamps the current library +VERSION+.
#
# @return [Hash{Symbol => Object}]
def to_dict
  output = Copy.deepcopy(_dict)
  output[:model_type] = self.class.model_type

  # Internal state that must not be serialized. `_dict` only strips the "@",
  # so the attention implementation (kept in @attn_implementation_internal)
  # appears without a leading underscore; the commit hash may appear as either
  # :_commit_hash or :commit_hash depending on which variable holds it.
  # The underscore-prefixed spellings are removed as well in case such keys
  # were supplied as extra options.
  %i[
    _auto_class
    _commit_hash
    commit_hash
    _attn_implementation_internal
    attn_implementation_internal
  ].each { |internal_key| output.delete(internal_key) }

  # Transformers version when serializing the model
  output[:transformers_version] = VERSION

  output
end

#to_diff_dict ⇒ Object



157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# File 'lib/transformers/configuration_utils.rb', line 157

# Serializes only the entries that differ from a default configuration.
#
# An entry from +to_dict+ is kept when the default +PretrainedConfig+ does not
# have that key, when its value differs from the default, or when the key is
# +:transformers_version+ (always kept). +:name_or_path+ is written out under
# the key +:_name_or_path+.
#
# @return [Hash{Symbol => Object}]
def to_diff_dict
  config_dict = to_dict

  # get the default config dict
  default_config_dict = PretrainedConfig.new.to_dict

  config_dict.each_with_object({}) do |(key, value), serializable_config_dict|
    # Compare against the defaults using the original key: the defaults only
    # know :name_or_path, so looking up the renamed key would never match and
    # a default-valued name would always be emitted.
    changed = !default_config_dict.include?(key) || value != default_config_dict[key]
    next unless changed || key == :transformers_version

    output_key = key == :name_or_path ? :_name_or_path : key
    serializable_config_dict[output_key] = value
  end
end

#to_json_string(use_diff: true) ⇒ Object



192
193
194
195
196
197
198
199
# File 'lib/transformers/configuration_utils.rb', line 192

# Renders the configuration as pretty-printed JSON with keys in sorted order
# and a trailing newline.
#
# @param use_diff [Boolean] when exactly +true+, serialize +to_diff_dict+;
#   any other value serializes the full +to_dict+
# @return [String]
def to_json_string(use_diff: true)
  payload = use_diff == true ? to_diff_dict : to_dict
  ordered = payload.sort_by(&:first).to_h
  "#{JSON.pretty_generate(ordered)}\n"
end

#to_s ⇒ Object



153
154
155
# File 'lib/transformers/configuration_utils.rb', line 153

# Human-readable form: the class name, a space, then the output of
# +to_json_string+.
#
# @return [String]
def to_s
  [self.class.name, to_json_string].join(" ")
end

#use_return_dict ⇒ Object



149
150
151
# File 'lib/transformers/configuration_utils.rb', line 149

# Returns the value of the +@return_dict+ instance variable.
#
# @return [Object, nil] the stored value, or nil when it was never assigned
def use_return_dict = @return_dict