Module: Transformers::ModelingUtils

Defined in:
lib/transformers/modeling_utils.rb

Constant Summary

# Torch::NN::Init initializers that no_init_weights swaps for no-ops, keyed by
# method name (String) and mapped to the Method objects captured when this file
# is loaded.
#
# Intentionally not registered (these were left commented out; presumably they
# are unavailable in the targeted torch-rb version — TODO confirm):
#   trunc_normal!, uniform, normal, xavier_uniform, xavier_normal,
#   kaiming_uniform, kaiming_normal
TORCH_INIT_FUNCTIONS = %w[
  uniform!
  normal!
  constant!
  xavier_uniform!
  xavier_normal!
  kaiming_uniform!
  kaiming_normal!
].to_h { |init_name| [init_name, Torch::NN::Init.method(init_name)] }

Class Method Summary

Class Method Details

.no_init_weights ⇒ Object

Private note: this improves loading time significantly, but it is not thread-safe!



(Source lines 37–55)
# File 'lib/transformers/modeling_utils.rb', line 37

# Runs the given block with Torch's weight initializers stubbed out.
#
# When +Transformers.fast_init+ is falsy, the block is simply yielded to and
# Torch::NN::Init is left untouched. Otherwise, every initializer named in
# TORCH_INIT_FUNCTIONS is replaced on Torch::NN::Init with a no-op (returning
# nil) for the duration of the block.
#
# Once the block finishes or raises, the implementations that were installed
# on entry are put back. Because the *current* implementations are saved,
# rather than the load-time ones, a nested call leaves its caller's no-op
# stubs in place.
#
# Note: this improves loading time significantly, but it mutates the shared
# Torch::NN::Init module object and is therefore not thread-safe.
#
# @yield the work to run without weight initialization
# @return [Object] the block's return value
def self.no_init_weights
  return yield unless Transformers.fast_init

  init = Torch::NN::Init
  # Snapshot what is installed right now, so the ensure clause restores
  # exactly this state. That may be real initializers, or an outer call's stubs.
  saved = TORCH_INIT_FUNCTIONS.keys.to_h { |name| [name, init.method(name)] }
  skip_init = ->(*_args, **_kwargs) {}
  # Undefine the singleton method, then define it again with +impl+.
  install = lambda do |name, impl|
    init.singleton_class.undef_method(name)
    init.define_singleton_method(name, impl)
  end

  begin
    saved.each_key { |name| install.call(name, skip_init) }
    yield
  ensure
    # Scoped to the patched region: the disabled path above never gets here,
    # so it cannot overwrite whatever is currently installed.
    saved.each { |name, impl| install.call(name, impl) }
  end
end