31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
# File 'lib/transformers/pipelines/base.rb', line 31
# Picks the model class for a pipeline and, when +model+ is a checkpoint
# identifier (a String), loads the model with it.
#
# Candidate classes are collected in order: first the "pt" entry of
# +model_classes+ (defaulting to [AutoModel] when that key is absent), then
# every name in +config.architectures+ that resolves to a class under the
# Transformers namespace. Only the first candidate is loaded, and it must be
# a subclass of BaseAutoModelClass.
#
# @param model [String, Object] checkpoint identifier, or an already
#   instantiated model which is passed through unchanged
# @param config [#architectures] model configuration (only read when
#   +model+ is a String)
# @param model_classes [Hash, nil] task-specific classes, read under "pt";
#   the value may be a list of classes or a single class
# @param task [String, nil] forwarded to from_pretrained as :_from_pipeline
# @param framework [String, nil] inferred from the model's class when nil
# @param model_kwargs [Hash] extra keyword arguments for from_pretrained
# @return [Array] two-element array [framework, model]
# @raise [ArgumentError] if no candidate model class can be determined
# @raise [Error] if the first candidate is not an auto model class
def self.infer_framework_load_model(
  model,
  config,
  model_classes: nil,
  task: nil,
  framework: nil,
  **model_kwargs
)
  if model.is_a?(String)
    model_kwargs[:_from_pipeline] = task

    class_tuple = []
    # The default must be a list so it can be concatenated; Array() also
    # tolerates a single class or an explicit nil under the "pt" key.
    class_tuple.concat(Array(model_classes.fetch("pt", [AutoModel]))) if model_classes

    # Architecture names come from the checkpoint's config and may name
    # classes this library does not define. Skip those instead of letting
    # const_get raise (the Python original uses getattr(..., None)).
    Array(config.architectures).each do |architecture|
      klass =
        begin
          Transformers.const_get(architecture)
        rescue NameError
          nil
        end
      class_tuple << klass if klass.is_a?(Class)
    end

    if class_tuple.empty?
      raise ArgumentError, "Pipeline cannot infer suitable model classes from #{model}"
    end

    # Only the first candidate is attempted; a load failure propagates to
    # the caller rather than falling back to later candidates. Anything
    # that is not an auto model class is rejected before from_pretrained
    # is called on it.
    model_class = class_tuple.first
    unless model_class.is_a?(Class) && model_class < BaseAutoModelClass
      raise Error, "Invalid auto model class: #{model_class}"
    end

    model = model_class.from_pretrained(model, **model_kwargs)
    model = model.eval if model.respond_to?(:eval)
  end

  framework = Utils.infer_framework(model.class) if framework.nil?
  [framework, model]
end
|