Class: Torch::NN::Module

Inherits:
Object
  • Object
show all
Includes:
Utils
Defined in:
lib/torch/nn/module.rb

Instance Method Summary collapse

Methods included from Utils

#_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize ⇒ Module

Returns a new instance of Module.



6
7
8
9
10
11
# File 'lib/torch/nn/module.rb', line 6

# Sets up the bookkeeping state of a new module: training mode is switched
# on and the parameter, buffer, and submodule registries start out empty.
def initialize
  @training = true
  # Three separate literals, so each registry is its own Hash.
  @parameters, @buffers, @modules = {}, {}, {}
end

Dynamic Method Handling

This class handles dynamic methods through the method_missing method

#method_missing(method, *args, &block) ⇒ Object



260
261
262
263
264
265
266
267
268
269
270
271
# File 'lib/torch/nn/module.rb', line 260

# Resolves otherwise-undefined method names against the module's named
# parameters, then its named buffers, then its named modules, returning the
# first entry whose key equals the method name. Anything else falls through
# to the default behaviour via +super+.
#
# Each lookup table is built at most once per call; later tables are only
# built when the earlier ones do not contain the name.
#
# @param method [Symbol, String] the name that was called
# @param args [Array] ignored when the name matches a registered entry
# @return [Object] the matching parameter, buffer, or module
def method_missing(method, *args, &block)
  name = method.to_s

  params = named_parameters
  return params[name] if params.key?(name)

  bufs = named_buffers
  # key? rather than a truthiness check: a buffer may be registered as nil.
  return bufs[name] if bufs.key?(name)

  mods = named_modules
  return mods[name] if mods.key?(name)

  super
end

Instance Method Details

#_apply(fn) ⇒ Object



33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# File 'lib/torch/nn/module.rb', line 33

# Applies +fn+ to the parameters, parameter gradients, and buffers of this
# module after first recursing into every child module.
#
# Parameters are found by scanning instance variables for Parameter objects;
# each is replaced by a new Parameter wrapping the converted value. Buffers
# are taken from the @buffers registry, and both the registry entry and the
# same-named instance variable (as set by #register_buffer) are updated so
# the two never disagree.
#
# NOTE(review): parameters that live only in the @parameters registry (see
# #register_parameter) are not Parameter-typed instance variables, so the
# scan below does not visit them - confirm whether that is intended.
#
# @param fn [#call] receives a parameter, gradient, or buffer and returns its replacement
# @return [self]
def _apply(fn)
  children.each do |mod|
    mod._apply(fn)
  end

  instance_variables.each do |key|
    param = instance_variable_get(key)
    next unless param.is_a?(Parameter)

    param_applied = nil
    Torch.no_grad do
      param_applied = fn.call(param)
    end
    # TODO should_use_set_data
    instance_variable_set(key, Parameter.new(param_applied, requires_grad: param.requires_grad))

    next unless param.grad

    grad_applied = nil
    Torch.no_grad do
      grad_applied = fn.call(param.grad)
    end
    # TODO should_use_set_data
    instance_variable_get(key).grad = grad_applied.requires_grad!(param.grad.requires_grad)
  end

  @buffers.each_key do |name|
    buf = @buffers[name]
    next if buf.nil?

    converted = fn.call(buf)
    @buffers[name] = converted
    # Keep the instance-variable copy in step with the registry entry.
    instance_variable_set("@#{name}", converted)
  end

  self
end

#add_module(name, mod) ⇒ Object



28
29
30
31
# File 'lib/torch/nn/module.rb', line 28

# Registers +mod+ as a child of this module under +name+ in the @modules
# registry, replacing any entry already stored under that name.
#
# @param name key under which the child is stored
# @param mod the child to register
# @return the registered +mod+
def add_module(name, mod)
  # TODO add checks
  @modules.store(name, mod)
end

#apply(fn) ⇒ Object



67
68
69
70
71
72
73
# File 'lib/torch/nn/module.rb', line 67

# Calls +fn+ on every module in the tree rooted here: each child is asked to
# apply +fn+ to itself first, and this module is handed to +fn+ last.
#
# @param fn [#call] callable invoked with a module
# @return [self]
def apply(fn)
  children.each { |child| child.apply(fn) }
  fn.call(self)
  self
end

#buffers ⇒ Object



166
167
168
# File 'lib/torch/nn/module.rb', line 166

# Lists the buffers of this module without their names.
#
# @return [Array] the values of #named_buffers, in registry order
def buffers
  registry = named_buffers
  registry.values
end

#call(*input, **kwargs) ⇒ Object



109
110
111
# File 'lib/torch/nn/module.rb', line 109

# Runs the module by delegating to #forward with exactly the positional and
# keyword arguments it was given.
#
# @return whatever #forward returns
def call(*input, **kwargs)
  forward(*input, **kwargs)
end

#children ⇒ Object



174
175
176
# File 'lib/torch/nn/module.rb', line 174

# Lists the direct child modules without their names.
#
# @return [Array] the values of #named_children
def children
  direct = named_children
  direct.values
end

#cpu ⇒ Object



80
81
82
# File 'lib/torch/nn/module.rb', line 80

# Converts the module through #_apply, calling +cpu+ on each value that
# #_apply visits.
#
# @return the result of #_apply
def cpu
  _apply(:cpu.to_proc)
end

#cuda ⇒ Object

TODO add device



76
77
78
# File 'lib/torch/nn/module.rb', line 76

# Converts the module through #_apply, calling +cuda+ on each value that
# #_apply visits.
#
# @return the result of #_apply
def cuda
  move = lambda do |tensor|
    tensor.cuda
  end
  _apply(move)
end

#double ⇒ Object



92
93
94
# File 'lib/torch/nn/module.rb', line 92

# Converts the module through #_apply: values that report +floating_point?+
# are replaced by the result of their +double+ method, all others are
# passed through unchanged.
#
# @return the result of #_apply
def double
  widen = lambda do |tensor|
    next tensor unless tensor.floating_point?

    tensor.double
  end
  _apply(widen)
end

#eval ⇒ Object



220
221
222
# File 'lib/torch/nn/module.rb', line 220

# Switches the module out of training mode; shorthand for +train(false)+.
#
# @return the result of #train
def eval
  training_mode = false
  train(training_mode)
end

#float ⇒ Object



88
89
90
# File 'lib/torch/nn/module.rb', line 88

# Converts the module through #_apply: values that report +floating_point?+
# are replaced by the result of their +float+ method, all others are passed
# through unchanged.
#
# @return the result of #_apply
def float
  to_single = lambda do |tensor|
    next tensor unless tensor.floating_point?

    tensor.float
  end
  _apply(to_single)
end

#forward ⇒ Object

Raises:

  • (NotImplementedError)


13
14
15
# File 'lib/torch/nn/module.rb', line 13

# Placeholder for the module's computation; subclasses are expected to
# override it.
#
# Any arguments are accepted and ignored so that invoking an unimplemented
# module through #call (which forwards its arguments here) reports
# NotImplementedError rather than an argument-count error.
#
# @raise [NotImplementedError] always
def forward(*_input, **_kwargs)
  raise NotImplementedError
end

#half ⇒ Object



96
97
98
# File 'lib/torch/nn/module.rb', line 96

# Converts the module through #_apply: values that report +floating_point?+
# are replaced by the result of their +half+ method, all others are passed
# through unchanged.
#
# @return the result of #_apply
def half
  narrow = lambda do |tensor|
    next tensor unless tensor.floating_point?

    tensor.half
  end
  _apply(narrow)
end

#inspect ⇒ Object



244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# File 'lib/torch/nn/module.rb', line 244

# Builds a printable description of the module.
#
# A module without children renders as "Name(<extra_inspect>)". Otherwise
# each child is rendered on its own line as "(child_name): <child inspect>",
# indented by two spaces, with continuation lines of a multi-line child
# indented by two further spaces.
#
# @return [String]
def inspect
  label = self.class.name.split("::").last
  kids = named_children
  return "#{label}(#{extra_inspect})" if kids.empty?

  out = String.new("#{label}(\n")
  kids.each do |child_name, mod|
    # Joining the lines with two spaces indents every line after the first.
    nested = mod.inspect.lines.join("  ")
    out << "  (#{child_name}): #{nested}\n"
  end
  out << ")"
end

#load_state_dict(state_dict) ⇒ Object

TODO: add strict option. TODO: match PyTorch behavior.



123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# File 'lib/torch/nn/module.rb', line 123

# Copies the values in +state_dict+ into the matching parameters of this
# module tree.
#
# Keys use the same naming as #named_parameters / #state_dict:
# * "weight" (no dot) addresses a parameter found through this module's own
#   #named_parameters;
# * "fc.weight" addresses parameter "weight" of the module registered under
#   "fc" in #named_modules (everything after the first dot is looked up in
#   that module's #named_parameters).
#
# @param state_dict [Hash{String => Object}] parameter name => value to copy in
# @return [nil]
# @raise [Error] when the module part of a key is not a known module, or the
#   key does not resolve to a Parameter
def load_state_dict(state_dict)
  # The module table does not depend on the key, so build it once.
  modules = named_modules

  state_dict.each do |key, input_param|
    owner_name, param_name = key.split(".", 2)

    if param_name.nil?
      # No dot: the key names one of this module's own parameters.
      param = named_parameters[owner_name]
    else
      owner = modules[owner_name]
      raise Error, "Unknown module: #{owner_name}" unless owner.is_a?(Module)

      param = owner.named_parameters[param_name]
    end
    raise Error, "Unknown parameter: #{key}" unless param.is_a?(Parameter)

    Torch.no_grad do
      param.copy!(input_param)
    end
  end

  # TODO return missing keys and unexpected keys
  nil
end

#modules ⇒ Object



190
191
192
# File 'lib/torch/nn/module.rb', line 190

# Lists every module reported by #named_modules, without the names.
#
# @return [Array] the values of #named_modules
def modules
  table = named_modules
  table.values
end

#named_buffers ⇒ Object



170
171
172
# File 'lib/torch/nn/module.rb', line 170

# Returns the buffer registry, falling back to an empty Hash when @buffers
# has not been set.
#
# @return [Hash] buffer name => buffer
def named_buffers
  return @buffers if @buffers

  {}
end

#named_children ⇒ Object



178
179
180
181
182
183
184
185
186
187
188
# File 'lib/torch/nn/module.rb', line 178

# Collects the direct children of this module, keyed by name.
#
# Children come from two places: instance variables whose value is a Module
# (keyed by the variable name without the leading "@"), followed by every
# entry of the @modules registry, which wins on a name clash.
#
# @return [Hash] child name => child
def named_children
  from_ivars = instance_variables.each_with_object({}) do |ivar, found|
    value = instance_variable_get(ivar)
    found[ivar[1..-1]] = value if value.is_a?(Module)
  end

  @modules.each_with_object(from_ivars) do |(child_name, child), found|
    found[child_name] = child
  end
end

#named_modules(memo: nil, prefix: "") ⇒ Object

TODO return enumerator?



195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# File 'lib/torch/nn/module.rb', line 195

# Walks the module tree and returns every module reachable from here, keyed
# by its dotted path; this module itself is stored under +prefix+.
#
# +memo+ records the modules already visited, so a module reachable through
# several paths is reported only once, under the first path found.
#
# TODO return enumerator?
#
# @param memo [Set, nil] modules already visited (a new Set when nil)
# @param prefix [String] path of this module within the tree
# @return [Hash] dotted path => module
def named_modules(memo: nil, prefix: "")
  memo ||= Set.new
  found = {}
  return found if memo.include?(self)

  memo << self
  found[prefix] = self

  # Only put a dot between path segments when there is a parent path.
  separator = prefix.empty? ? "" : "."
  named_children.each do |child_name, child|
    next unless child.is_a?(Module)

    nested = child.named_modules(memo: memo, prefix: prefix + separator + child_name)
    nested.each { |path, mod| found[path] = mod }
  end

  found
end

#named_parameters(prefix: "", recurse: true) ⇒ Object



149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# File 'lib/torch/nn/module.rb', line 149

# Collects parameters keyed by their prefixed name.
#
# Sources, in insertion order: the parameters of every child (only when
# +recurse+ is true, each child using "<prefix><child name>." as its
# prefix), then instance variables holding a Parameter, then the non-nil
# entries of the @parameters registry.
#
# @param prefix [String] text prepended to every key
# @param recurse [Boolean] whether to include the children's parameters
# @return [Hash] prefixed name => parameter
def named_parameters(prefix: "", recurse: true)
  collected = {}

  if recurse
    named_children.each do |child_name, child|
      nested = child.named_parameters(prefix: "#{prefix}#{child_name}.", recurse: recurse)
      collected.update(nested)
    end
  end

  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    next unless value.is_a?(Parameter)

    # Drop the leading "@" to get the parameter's name.
    collected[[prefix, ivar[1..-1]].join] = value
  end

  @parameters.each do |param_name, param|
    next unless param

    collected[[prefix, param_name].join] = param
  end

  collected
end

#parameters ⇒ Object



145
146
147
# File 'lib/torch/nn/module.rb', line 145

# Lists every parameter reported by #named_parameters, without the names.
#
# @return [Array] the values of #named_parameters
def parameters
  by_name = named_parameters
  by_name.values
end

#register_buffer(name, tensor) ⇒ Object



17
18
19
20
21
# File 'lib/torch/nn/module.rb', line 17

# Registers +tensor+ as a buffer called +name+: it is stored in the @buffers
# registry and also assigned to the instance variable "@<name>".
#
# @param name buffer name; must form a valid instance variable name
# @param tensor the buffer value
# @return the given +tensor+
def register_buffer(name, tensor)
  # TODO add checks
  @buffers.store(name, tensor)
  instance_variable_set(:"@#{name}", tensor)
end

#register_parameter(name, param) ⇒ Object



23
24
25
26
# File 'lib/torch/nn/module.rb', line 23

# Stores +param+ under +name+ in the @parameters registry, replacing any
# entry already stored under that name.
#
# @param name key under which the parameter is stored
# @param param the parameter to register
# @return the given +param+
def register_parameter(name, param)
  # TODO add checks
  @parameters.store(name, param)
end

#requires_grad!(requires_grad: true) ⇒ Object



224
225
226
227
228
229
# File 'lib/torch/nn/module.rb', line 224

# Calls +requires_grad!+ with the given flag on every parameter returned by
# #parameters.
#
# @param requires_grad [Boolean] flag passed on to each parameter
# @return [self]
def requires_grad!(requires_grad: true)
  parameters.each { |param| param.requires_grad!(requires_grad) }
  self
end

#respond_to?(method, include_private = false) ⇒ Boolean

Returns:

  • (Boolean)


273
274
275
276
# File 'lib/torch/nn/module.rb', line 273

# Reports whether the module responds to +method+, either as a regular
# method or as the name of a parameter, buffer, or module that
# #method_missing would resolve.
#
# Regular methods are checked first, so the lookup tables (which are rebuilt
# on every call) are only constructed for names that are not ordinary
# methods.
#
# NOTE(review): Ruby's convention for pairing with method_missing is to
# override respond_to_missing? instead; kept as respond_to? here to leave
# the public method set unchanged.
#
# @param method [Symbol, String] the name to check
# @param include_private [Boolean] passed through to the default check
# @return [Boolean]
def respond_to?(method, include_private = false)
  return true if super

  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name)
end

#share_memory ⇒ Object



240
241
242
# File 'lib/torch/nn/module.rb', line 240

# Converts the module through #_apply, calling +share_memory!+ on each value
# that #_apply visits.
#
# @return the result of #_apply
def share_memory
  _apply(:share_memory!.to_proc)
end

#state_dict(destination: nil) ⇒ Object



113
114
115
116
117
118
119
# File 'lib/torch/nn/module.rb', line 113

# Copies every entry of #named_parameters into +destination+ and returns it.
#
# @param destination [Hash, nil] container to fill; a new Hash when omitted
# @return the filled +destination+
def state_dict(destination: nil)
  named_parameters.each_with_object(destination || {}) do |(key, param), target|
    target[key] = param
  end
end

#to(device) ⇒ Object

modifies in-place



101
102
103
104
105
106
107
# File 'lib/torch/nn/module.rb', line 101

# Converts the module through #_apply, calling +to(device)+ on each value
# that #_apply visits.
#
# @param device argument passed on to each value's +to+ method
# @return the result of #_apply
def to(device)
  _apply(->(tensor) { tensor.to(device) })
end

#train(mode = true) ⇒ Object



212
213
214
215
216
217
218
# File 'lib/torch/nn/module.rb', line 212

# Records +mode+ as this module's training flag and passes the same mode on
# to every child.
#
# @param mode [Boolean] true for training mode, false for evaluation mode
# @return [self]
def train(mode = true)
  @training = mode
  children.each { |child| child.train(mode) }
  self
end

#type(dst_type) ⇒ Object



84
85
86
# File 'lib/torch/nn/module.rb', line 84

# Converts the module through #_apply, calling +type(dst_type)+ on each
# value that #_apply visits.
#
# @param dst_type argument passed on to each value's +type+ method
# @return the result of #_apply
def type(dst_type)
  cast = lambda do |tensor|
    tensor.type(dst_type)
  end
  _apply(cast)
end

#zero_grad ⇒ Object



231
232
233
234
235
236
237
238
# File 'lib/torch/nn/module.rb', line 231

# For every parameter that has a gradient, calls +detach!+ and then +zero!+
# on that gradient. Parameters without a gradient are skipped.
#
# @return the collection returned by #parameters
def zero_grad
  parameters.each do |param|
    next unless param.grad

    param.grad.detach!
    param.grad.zero!
  end
end