Module: Transformers::Utils

Defined in:
lib/transformers/utils/hub.rb,
lib/transformers/utils/generic.rb,
lib/transformers/utils/import_utils.rb

Defined Under Namespace

Modules: Hub

Constant Summary collapse

# Upper-case strings that count as "true".
# NOTE(review): presumably compared against upper-cased environment variable
# values; the call site is not visible here — confirm.
# Frozen so the shared constant cannot be mutated by accident.
ENV_VARS_TRUE_VALUES =
["1", "ON", "YES", "TRUE"].freeze
# Error-message template for a missing vision backend. The `%s` placeholder is
# filled via `format` in requires_backends with the name of the guarded caller.
# The heredoc leaves a trailing newline on the message.
VISION_IMPORT_ERROR =
<<~MSG
%s requires the `ruby-vips` gem
MSG
# Maps a backend name to a pair of [availability check, error-message template].
# requires_backends calls the first element and, when it returns a falsy value,
# formats the second element into the raised error.
BACKENDS_MAPPING =
{
  "vision" => [singleton_method(:is_vision_available), VISION_IMPORT_ERROR]
}

Class Method Summary collapse

Class Method Details

._is_numo(x) ⇒ Object



83
84
85
# File 'lib/transformers/utils/generic.rb', line 83

# Reports whether +x+ is a Numo::NArray (or an instance of one of its subclasses).
#
# @param x [Object] value to test
# @return [Boolean]
def self._is_numo(x)
  x.is_a?(Numo::NArray)
end

._is_torch(x) ⇒ Object



91
92
93
# File 'lib/transformers/utils/generic.rb', line 91

# Reports whether +x+ is a Torch::Tensor (or an instance of one of its subclasses).
#
# @param x [Object] value to test
# @return [Boolean]
def self._is_torch(x)
  x.is_a?(Torch::Tensor)
end

._is_torch_device(x) ⇒ Object



99
100
101
# File 'lib/transformers/utils/generic.rb', line 99

# Reports whether +x+ is a Torch::Device (or an instance of one of its subclasses).
#
# @param x [Object] value to test
# @return [Boolean]
def self._is_torch_device(x)
  x.is_a?(Torch::Device)
end

.infer_framework(model_class) ⇒ Object



75
76
77
78
79
80
81
# File 'lib/transformers/utils/generic.rb', line 75

# Infers the framework identifier for a model class.
#
# @param model_class [Class] candidate model class
# @return [String] "pt" when +model_class+ is a strict descendant of
#   Torch::NN::Module
# @raise [TypeError] for anything else, including Torch::NN::Module itself,
#   unrelated classes, and non-class arguments
def self.infer_framework(model_class)
  # Guard first: `<` with a class operand is only meaningful for classes and
  # modules. Without the check, nil raises NoMethodError and values such as
  # strings or integers raise a comparison error instead of the TypeError below.
  if model_class.is_a?(::Module) && model_class < Torch::NN::Module
    "pt"
  else
    raise TypeError, "Could not infer framework from class #{model_class}."
  end
end

.is_numo_array(x) ⇒ Object



87
88
89
# File 'lib/transformers/utils/generic.rb', line 87

# Reports whether +x+ is a Numo array by delegating to _is_numo.
#
# @param x [Object] value to test
# @return [Boolean] whatever _is_numo returns for +x+
def self.is_numo_array(x)
  _is_numo(x)
end

.is_torch_device(x) ⇒ Object



103
104
105
# File 'lib/transformers/utils/generic.rb', line 103

# Reports whether +x+ is a Torch device by delegating to _is_torch_device.
#
# @param x [Object] value to test
# @return [Boolean] whatever _is_torch_device returns for +x+
def self.is_torch_device(x)
  _is_torch_device(x)
end

.is_torch_tensor(x) ⇒ Object



95
96
97
# File 'lib/transformers/utils/generic.rb', line 95

# Reports whether +x+ is a Torch tensor by delegating to _is_torch.
#
# @param x [Object] value to test
# @return [Boolean] whatever _is_torch returns for +x+
def self.is_torch_tensor(x)
  _is_torch(x)
end

.is_vision_available ⇒ Object



33
34
35
# File 'lib/transformers/utils/import_utils.rb', line 33

# Reports whether the Vips constant is currently resolvable, which is how the
# "vision" backend (the `ruby-vips` gem) is detected.
#
# The result is the raw value of `defined?`, so it is truthy/falsy rather than
# a strict boolean.
#
# @return [String, nil] "constant" when Vips is defined, nil otherwise
def self.is_vision_available
  defined?(Vips)
end

.requires_backends(obj, backends) ⇒ Object



19
20
21
22
23
24
25
26
27
28
29
30
31
# File 'lib/transformers/utils/import_utils.rb', line 19

# Ensures every backend in +backends+ is available, raising otherwise.
#
# @param obj [Symbol, Module, Object] the caller being guarded; used only to
#   build the error message (a Symbol is used as-is, a class or module is
#   reported by its own name, any other object by its class name)
# @param backends [String, Array<String>] one or more BACKENDS_MAPPING keys
# @return [nil] when all requested backends are available
# @raise [KeyError] if a backend is not a key of BACKENDS_MAPPING
# @raise [Error] if at least one backend is unavailable; the message is the
#   concatenation of each failed backend's formatted template
def self.requires_backends(obj, backends)
  backends = [backends] unless backends.is_a?(Array)

  # A class or module must be reported by its own name: `obj.class.name`
  # would yield the unhelpful "Class" / "Module". Anonymous ones have no
  # name, so fall back to their string form.
  name =
    case obj
    when Symbol then obj
    when ::Module then obj.name || obj.to_s
    else obj.class.name
    end

  # Resolve every backend up front so an unknown key fails before any
  # availability check runs.
  checks = backends.map { |backend| BACKENDS_MAPPING.fetch(backend) }
  failed = checks.filter_map { |available, msg| format(msg, name) unless available.call }
  raise Error, failed.join if failed.any?
end