Module: Torch::Native::Dispatcher

Defined in:
lib/torch/native/dispatcher.rb

Class Method Summary

Class Method Details

.bind ⇒ Object



17
18
19
20
21
22
# File 'lib/torch/native/dispatcher.rb', line 17

# Binds every group returned by Generator.grouped_functions onto its target:
# the :torch group as singleton methods on ::Torch, the :tensor group as
# instance methods on ::Torch::Tensor, and the :nn group as singleton methods
# on ::Torch::NN.
#
# Returns whatever the final bind_functions call (the :nn group) returns.
def bind
  grouped = Generator.grouped_functions

  # [target, defining method, key into the grouped functions], in bind order.
  bindings = [
    [::Torch, :define_singleton_method, :torch],
    [::Torch::Tensor, :define_method, :tensor],
    [::Torch::NN, :define_singleton_method, :nn]
  ]

  # map + last hands back the result of the last bind_functions call.
  bindings.map { |context, def_method, key| bind_functions(context, def_method, grouped[key]) }.last
end

.bind_functions(context, def_method, functions) ⇒ Object



24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# File 'lib/torch/native/dispatcher.rb', line 24

# Defines one dispatching method on +context+ per distinct ruby_name found in
# +functions+, processing the names in sorted order.
#
# @param context [Module] receiver of the new methods
# @param def_method [Symbol] message sent to +context+ to define each method;
#   :define_method selects instance-method handling, any other value is
#   treated as a singleton-style definition
# @param functions [Enumerable] objects responding to +ruby_name+ (and to
#   +function+ when def_method is :define_method)
#
# Each generated method forwards its positional and keyword arguments to a
# Parser built from the functions sharing that name, raises ArgumentError with
# the parser's :error when one is reported, and otherwise sends the parsed
# :name with the parsed :args to the receiver.
def bind_functions(context, def_method, functions)
  instance_binding = def_method == :define_method

  functions.group_by(&:ruby_name).sort_by(&:first).each do |name, funcs|
    if instance_binding
      # Work on fresh Function objects so dropping the "self" argument does
      # not alter the functions that were passed in.
      funcs.map! { |original| Function.new(original.function) }
      funcs.each do |copy|
        copy.args.reject! { |arg| arg[:name] == "self" }
      end
    end

    already_bound =
      if instance_binding
        context.method_defined?(name)
      else
        context.respond_to?(name)
      end

    # Existing methods are left untouched, except "clone", which is always
    # (re)defined -- presumably because Ruby's built-in clone would otherwise
    # always count as already defined; confirm against the project history.
    next if name != "clone" && already_bound

    parser = Parser.new(funcs)

    context.send(def_method, name) do |*args, **options|
      parsed = parser.parse(args, options)
      error = parsed[:error]
      raise ArgumentError, error if error

      send(parsed[:name], *parsed[:args])
    end
  end
end