Class: Torch::NN::Module

Inherits:
Object
Defined in:
lib/torch/nn/module.rb

Instance Method Summary

Constructor Details

#initialize ⇒ Module

Returns a new instance of Module.



4
5
6
7
8
9
# File 'lib/torch/nn/module.rb', line 4

# Builds an empty module that starts out in training mode.
#
# The three hashes are the explicit registries written by
# register_parameter, register_buffer and add_module respectively.
# Submodules/parameters held directly in instance variables are discovered
# separately by reflection, so they need no entry here.
def initialize
  @training = true
  # name => Parameter, filled by register_parameter
  @parameters = {}
  # name => tensor, filled by register_buffer
  @buffers = {}
  # name => child module, filled by add_module
  @modules = {}
end

Dynamic Method Handling

This class exposes registered parameters, buffers and submodules as dynamic reader methods (e.g. `mod.weight`) through method_missing.

#method_missing(method, *args, &block) ⇒ Object



186
187
188
189
190
191
192
193
194
195
196
197
# File 'lib/torch/nn/module.rb', line 186

# Serves parameters, buffers and submodules as reader methods,
# e.g. `mod.weight` or `mod.fc1`.
#
# Lookup order is parameters, then buffers, then modules. Each registry is
# fetched at most once per lookup: named_parameters recurses through every
# submodule, so calling it once for key? and again for [] doubled that walk.
# Any arguments or block passed on a hit are ignored. Unknown names fall
# through to super, which raises NoMethodError.
#
# @param method [Symbol] the missing method's name
# @return [Object] the matching parameter, buffer or module
def method_missing(method, *args, &block)
  name = method.to_s

  params = named_parameters
  return params[name] if params.key?(name)

  buffers = named_buffers
  return buffers[name] if buffers.key?(name)

  mods = named_modules
  return mods[name] if mods.key?(name)

  super
end

Instance Method Details

#_apply(fn) ⇒ Object



30
31
32
33
34
35
36
# File 'lib/torch/nn/module.rb', line 30

# Propagates a tensor-conversion callable through the module tree.
#
# At present this only recurses into the direct children, which recurse in
# turn. fn itself is never invoked here.
# TODO apply fn to this module's own parameters and buffers as well.
#
# @param fn [#call] callable presumably mapping a tensor to a tensor
# @return [self]
def _apply(fn)
  tap do
    children.each { |child| child._apply(fn) }
  end
end

#add_module(name, mod) ⇒ Object



25
26
27
28
# File 'lib/torch/nn/module.rb', line 25

# Registers mod as a named child of this module.
#
# NOTE: the key is stored as given. The dynamic reader lookup compares
# against `method.to_s`, so pass a String name if `mod.<name>` must work.
# TODO add checks (no validation of name or mod is performed yet).
#
# @param name [String] child name
# @param mod [Module] the child module
# @return [Module] mod
def add_module(name, mod)
  registry = @modules
  registry[name] = mod
end

#apply(fn) ⇒ Object



38
39
40
41
42
43
44
# File 'lib/torch/nn/module.rb', line 38

# Calls fn on every module in the tree, in post-order: each child subtree
# is visited first, then fn is called with this module.
#
# @param fn [#call] receives each module in turn
# @return [self]
def apply(fn)
  tap do |mod|
    children.each { |child| child.apply(fn) }
    fn.call(mod)
  end
end

#buffers ⇒ Object



108
109
110
# File 'lib/torch/nn/module.rb', line 108

# Lists this module's buffers without their names.
#
# @return [Array] the values of named_buffers, in insertion order
def buffers
  registry = named_buffers
  registry.values
end

#call(*input) ⇒ Object



79
80
81
# File 'lib/torch/nn/module.rb', line 79

# Runs the module by delegating to forward.
#
# Positional arguments, keyword arguments and any block are passed through
# unchanged. Without the explicit **kwargs, Ruby 3 would hand keywords to
# forward as a trailing positional Hash, breaking forward methods that take
# keyword parameters.
#
# @return [Object] whatever forward returns
def call(*input, **kwargs, &block)
  forward(*input, **kwargs, &block)
end

#children ⇒ Object



116
117
118
# File 'lib/torch/nn/module.rb', line 116

# Lists the direct child modules without their names.
#
# @return [Array] the values of named_children, in insertion order
def children
  by_name = named_children
  by_name.values
end

#cpu ⇒ Object



50
51
52
# File 'lib/torch/nn/module.rb', line 50

# Requests that the module's tensors be moved to the CPU.
#
# Passes a converter that calls #cpu on each tensor to _apply.
#
# @return [Object] the result of _apply
def cpu
  to_cpu = ->(tensor) { tensor.cpu }
  _apply(to_cpu)
end

#cuda(device: nil) ⇒ Object



46
47
48
# File 'lib/torch/nn/module.rb', line 46

# Requests that the module's tensors be moved to a CUDA device.
#
# @param device [Object, nil] forwarded as the single positional argument to
#   each tensor's #cuda (nil presumably selects the default device; confirm)
# @return [Object] the result of _apply
def cuda(device: nil)
  to_cuda = ->(tensor) { tensor.cuda(device) }
  _apply(to_cuda)
end

#double ⇒ Object



62
63
64
# File 'lib/torch/nn/module.rb', line 62

# Requests conversion of floating-point tensors via #double.
#
# Tensors for which floating_point? is false are passed through untouched.
#
# @return [Object] the result of _apply
def double
  converter = lambda do |tensor|
    if tensor.floating_point?
      tensor.double
    else
      tensor
    end
  end
  _apply(converter)
end

#eval ⇒ Object



148
149
150
# File 'lib/torch/nn/module.rb', line 148

# Switches the module, and recursively its children, to evaluation mode.
#
# Shorthand for train(false).
#
# @return [Object] whatever train returns
def eval
  training_mode = false
  train(training_mode)
end

#float ⇒ Object



58
59
60
# File 'lib/torch/nn/module.rb', line 58

# Requests conversion of floating-point tensors via #float.
#
# Tensors for which floating_point? is false are passed through untouched.
#
# @return [Object] the result of _apply
def float
  converter = lambda do |tensor|
    if tensor.floating_point?
      tensor.float
    else
      tensor
    end
  end
  _apply(converter)
end

#forward ⇒ Object

Raises:

  • (NotImplementedError)


11
12
13
# File 'lib/torch/nn/module.rb', line 11

# Computes the module's output. Every concrete module must override this.
#
# `call` forwards its arguments here, so the base version accepts (and
# ignores) any arguments. A subclass that forgot to define forward then
# gets the intended NotImplementedError instead of an unrelated
# ArgumentError about arity.
#
# @raise [NotImplementedError] always, in the base class
def forward(*)
  raise NotImplementedError, "#{self.class.name} must implement #forward"
end

#half ⇒ Object



66
67
68
# File 'lib/torch/nn/module.rb', line 66

# Requests conversion of floating-point tensors via #half.
#
# Tensors for which floating_point? is false are passed through untouched.
#
# @return [Object] the result of _apply
def half
  converter = lambda do |tensor|
    if tensor.floating_point?
      tensor.half
    else
      tensor
    end
  end
  _apply(converter)
end

#inspect ⇒ Object



172
173
174
175
176
177
178
179
180
181
182
183
184
# File 'lib/torch/nn/module.rb', line 172

# Returns a human-readable, possibly multi-line description of the module.
#
# A module without children renders as "ClassName(<extra_inspect>)".
# A container renders one line per direct child, "  (child_name): <child>".
# Every line after the first of a child's own output is indented two more
# spaces, so nested containers stay aligned.
#
# @return [String]
def inspect
  class_name = self.class.name.split("::").last
  kids = named_children
  return "#{class_name}(#{extra_inspect})" if kids.empty?

  str = String.new
  str << "#{class_name}(\n"
  # Iterate the name => module hash. Iterating the bare module list would
  # bind the module to the name slot and leave the module slot nil.
  kids.each do |child_name, mod|
    str << "  (#{child_name}): #{mod.inspect.lines.join("  ")}\n"
  end
  str << ")"
end

#modules ⇒ Object



132
133
134
# File 'lib/torch/nn/module.rb', line 132

# Lists the modules from named_modules (this module first) without names.
#
# @return [Array]
def modules
  by_name = named_modules
  by_name.values
end

#named_buffers ⇒ Object



112
113
114
# File 'lib/torch/nn/module.rb', line 112

# Returns the explicit buffer registry (name => tensor).
#
# Falls back to a fresh empty Hash when @buffers was never initialized,
# e.g. when a subclass's initialize did not call super.
#
# @return [Hash] the live @buffers hash itself when set (not a copy)
def named_buffers
  return @buffers if @buffers

  {}
end

#named_children ⇒ Object



120
121
122
123
124
125
126
127
128
129
130
# File 'lib/torch/nn/module.rb', line 120

# Returns this module's direct submodules as a name => module Hash.
#
# Two sources are merged, in this order:
#   1. instance variables whose value is a Module (key is the ivar name
#      without its leading "@")
#   2. entries registered via add_module (@modules); these replace an
#      ivar entry of the same name
# @modules may be nil when a subclass's initialize skipped super. It is
# then treated as empty, the same fallback named_buffers uses for @buffers.
#
# @return [Hash{String => Module}]
def named_children
  found = {}
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    found[ivar[1..-1]] = value if value.is_a?(Module)
  end
  (@modules || {}).each do |name, mod|
    found[name] = mod
  end
  found
end

#named_modules ⇒ Object



136
137
138
# File 'lib/torch/nn/module.rb', line 136

# Returns this module under the empty name "" followed by its direct
# children from named_children. This is not a recursive walk.
#
# @return [Hash{String => Module}] a new Hash
def named_modules
  root = {"" => self}
  root.merge(named_children)
end

#named_parameters(prefix: "", recurse: true) ⇒ Object



91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# File 'lib/torch/nn/module.rb', line 91

# Returns every parameter reachable from this module, keyed by dotted name.
#
# Sources, in insertion order:
#   1. (when recurse) parameters of each child, prefixed "<prefix><child>."
#   2. instance variables holding a Parameter
#   3. the explicit registry filled by register_parameter (@parameters)
#
# @param prefix [String] prepended to every key. Recursive calls pass the
#   accumulated path so that nested names are fully qualified ("a.b.w").
# @param recurse [Boolean] whether to descend into submodules
# @return [Hash{String => Parameter}]
def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      # Carry the accumulated prefix down; otherwise grandchildren would be
      # reported without their ancestors' names.
      params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    params["#{prefix}#{ivar[1..-1]}"] = value if value.is_a?(Parameter)
  end
  # Treat a missing registry as empty, like named_buffers does for @buffers.
  (@parameters || {}).each do |name, param|
    params["#{prefix}#{name}"] = param
  end
  params
end

#parameters ⇒ Object



87
88
89
# File 'lib/torch/nn/module.rb', line 87

# Lists every parameter from named_parameters (with its default options,
# i.e. recursive) without the names.
#
# @return [Array]
def parameters
  by_name = named_parameters
  by_name.values
end

#register_buffer(name, tensor) ⇒ Object



15
16
17
18
# File 'lib/torch/nn/module.rb', line 15

# Stores tensor in this module's explicit buffer registry under name.
#
# NOTE: the key is stored as given (named_buffers returns the registry
# directly). The dynamic reader lookup compares against `method.to_s`, so
# use a String name if `mod.<name>` must work.
# TODO add checks (no validation of name or tensor is performed yet).
#
# @param name [String] buffer name
# @param tensor [Object] the buffer value
# @return [Object] tensor
def register_buffer(name, tensor)
  registry = @buffers
  registry[name] = tensor
end

#register_parameter(name, param) ⇒ Object



20
21
22
23
# File 'lib/torch/nn/module.rb', line 20

# Stores param in this module's explicit parameter registry under name.
#
# TODO add checks (no validation of name or param is performed yet).
#
# @param name [String] parameter name
# @param param [Object] the parameter value
# @return [Object] param
def register_parameter(name, param)
  registry = @parameters
  registry[name] = param
end

#requires_grad!(requires_grad: true) ⇒ Object



152
153
154
155
156
157
# File 'lib/torch/nn/module.rb', line 152

# Calls requires_grad!(requires_grad) on every parameter from `parameters`.
#
# @param requires_grad [Boolean] flag forwarded to each parameter
# @return [self]
def requires_grad!(requires_grad: true)
  tap do
    parameters.each { |param| param.requires_grad!(requires_grad) }
  end
end

#respond_to?(method, include_private = false) ⇒ Boolean

Returns:

  • (Boolean)


199
200
201
202
# File 'lib/torch/nn/module.rb', line 199

def respond_to?(method, include_private = false)
  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end

#share_memory ⇒ Object



168
169
170
# File 'lib/torch/nn/module.rb', line 168

# Passes a converter that calls #share_memory! on each tensor to _apply.
#
# @return [Object] the result of _apply
def share_memory
  sharer = ->(tensor) { tensor.share_memory! }
  _apply(sharer)
end

#state_dict ⇒ Object

Raises:

  • (NotImplementedYet)

83
84
85
# File 'lib/torch/nn/module.rb', line 83

# Serializing module state is not supported yet.
#
# @raise [NotImplementedYet] always
def state_dict
  raise(NotImplementedYet)
end

#to(device) ⇒ Object

Modifies the module in place and returns it.



71
72
73
74
75
76
77
# File 'lib/torch/nn/module.rb', line 71

# Requests that the module's tensors be moved to device.
#
# Passes a converter that calls #to(device) on each tensor to _apply.
#
# @param device [Object] forwarded unchanged to each tensor's #to
# @return [Object] the result of _apply
def to(device)
  _apply(->(tensor) { tensor.to(device) })
end

#train(mode = true) ⇒ Object



140
141
142
143
144
145
146
# File 'lib/torch/nn/module.rb', line 140

# Sets the training flag on this module and, recursively, on its children.
#
# @param mode [Boolean] true for training mode, false for evaluation mode
# @return [self]
def train(mode = true)
  tap do
    @training = mode
    children.each { |child| child.train(mode) }
  end
end

#type(dst_type) ⇒ Object



54
55
56
# File 'lib/torch/nn/module.rb', line 54

# Passes a converter that calls #type(dst_type) on each tensor to _apply.
#
# @param dst_type [Object] forwarded unchanged to each tensor's #type
# @return [Object] the result of _apply
def type(dst_type)
  caster = ->(tensor) { tensor.type(dst_type) }
  _apply(caster)
end

#zero_grad ⇒ Object



159
160
161
162
163
164
165
166
# File 'lib/torch/nn/module.rb', line 159

# Resets accumulated gradients on every parameter of this module.
#
# Parameters whose grad is nil are skipped. For the others, detach! and then
# zero! are called on the grad.
#
# @return [Array] the parameter list (the return value of parameters.each)
def zero_grad
  parameters.each do |param|
    next unless param.grad

    param.grad.detach!
    param.grad.zero!
  end
end