Class: Torch::NN::Module
- Inherits:
-
Object
show all
- Includes:
- Utils
- Defined in:
- lib/torch/nn/module.rb
Direct Known Subclasses
AdaptiveAvgPoolNd, AdaptiveMaxPoolNd, AvgPoolNd, BatchNorm, Bilinear, ConstantPadNd, ConvNd, CosineSimilarity, DropoutNd, Embedding, EmbeddingBag, Fold, GroupNorm, Hardshrink, Identity, LPPoolNd, LayerNorm, LeakyReLU, Linear, LocalResponseNorm, LogSigmoid, LogSoftmax, Loss, MaxPoolNd, MaxUnpoolNd, PReLU, PairwiseDistance, RNNBase, ReLU, ReflectionPadNd, ReplicationPadNd, Sequential, Sigmoid, Softmax, Softmax2d, Softmin, Softplus, Softshrink, Softsign, Tanh, Tanhshrink, Unfold
Instance Method Summary
collapse
Methods included from Utils
#_ntuple, #_pair, #_quadrupal, #_single, #_triple
Constructor Details
#initialize ⇒ Module
Returns a new instance of Module.
6
7
8
9
10
11
|
# File 'lib/torch/nn/module.rb', line 6
# Fresh module state: training mode enabled and empty registries
# for parameters, buffers, and child modules.
def initialize
  @training = true
  @parameters, @buffers, @modules = {}, {}, {}
end
|
Dynamic Method Handling
This class handles dynamic methods through the method_missing method
#method_missing(method, *args, &block) ⇒ Object
260
261
262
263
264
265
266
267
268
269
270
271
|
# File 'lib/torch/nn/module.rb', line 260
# Exposes registered parameters, buffers, and child modules as
# reader methods (e.g. model.weight). Anything unregistered falls
# through to the default behavior (NoMethodError).
def method_missing(method, *args, &block)
  key = method.to_s
  return named_parameters[key] if named_parameters.key?(key)
  return named_buffers[key] if named_buffers.key?(key)
  return named_modules[key] if named_modules.key?(key)
  super
end
|
Instance Method Details
#_apply(fn) ⇒ Object
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
# File 'lib/torch/nn/module.rb', line 33
# Recursively applies +fn+ to every parameter and buffer tensor in
# this module and its children, rebuilding each Parameter around the
# transformed tensor. Backbone of #to, #cpu, #cuda, #float, etc.
# Returns self.
def _apply(fn)
# Depth-first: transform all children before this module's own tensors.
children.each do |mod|
mod._apply(fn)
end
instance_variables.each do |key|
param = instance_variable_get(key)
if param.is_a?(Parameter)
param_applied = nil
# Run fn outside autograd so the transform itself is not recorded.
Torch.no_grad do
param_applied = fn.call(param)
end
# Replace the ivar with a fresh Parameter wrapping the result,
# preserving the original requires_grad flag.
instance_variable_set(key, Parameter.new(param_applied, requires_grad: param.requires_grad))
if param.grad
grad_applied = nil
Torch.no_grad do
grad_applied = fn.call(param.grad)
end
# Carry the transformed gradient over to the new Parameter.
instance_variable_get(key).grad = grad_applied.requires_grad!(param.grad.requires_grad)
end
end
end
# Transform non-nil buffers in the registry in place.
# NOTE(review): the ivar set by register_buffer is not updated here,
# only the @buffers entry — confirm this is intended.
@buffers.each_key do |k|
buf = @buffers[k]
@buffers[k] = fn.call(buf) unless buf.nil?
end
self
end
|
#add_module(name, mod) ⇒ Object
28
29
30
31
|
# File 'lib/torch/nn/module.rb', line 28
# Registers +mod+ as a child module under +name+; returns +mod+.
def add_module(name, mod)
  @modules.store(name, mod)
end
|
#apply(fn) ⇒ Object
67
68
69
70
71
72
73
|
# File 'lib/torch/nn/module.rb', line 67
# Calls +fn+ on every child module (recursively, children first),
# then on self. Returns self so calls can be chained.
def apply(fn)
  children.each { |child| child.apply(fn) }
  fn.call(self)
  self
end
|
#buffers ⇒ Object
166
167
168
|
# File 'lib/torch/nn/module.rb', line 166
# All registered buffer tensors, without their names.
def buffers
  named_buffers.map { |_name, buffer| buffer }
end
|
#call(*input, **kwargs) ⇒ Object
109
110
111
|
# File 'lib/torch/nn/module.rb', line 109
# Invoking a module like a function runs its #forward with the
# given positional and keyword arguments.
def call(*input, **kwargs)
forward(*input, **kwargs)
end
|
#children ⇒ Object
174
175
176
|
# File 'lib/torch/nn/module.rb', line 174
# Direct child modules, without their names.
def children
  named_children.map { |_name, child| child }
end
|
#cpu ⇒ Object
80
81
82
|
# File 'lib/torch/nn/module.rb', line 80
# Moves every parameter and buffer to the CPU.
def cpu
  _apply(lambda { |tensor| tensor.cpu })
end
|
#cuda ⇒ Object
76
77
78
|
# File 'lib/torch/nn/module.rb', line 76
# Moves every parameter and buffer to the GPU.
def cuda
  _apply(lambda { |tensor| tensor.cuda })
end
|
#double ⇒ Object
92
93
94
|
# File 'lib/torch/nn/module.rb', line 92
# Casts floating-point tensors to float64; non-float tensors pass
# through untouched.
def double
  _apply(lambda { |tensor| tensor.floating_point? ? tensor.double : tensor })
end
|
#eval ⇒ Object
220
221
222
|
# File 'lib/torch/nn/module.rb', line 220
# Puts the module (and, via #train, all children) into evaluation
# mode — i.e. training disabled.
def eval
train(false)
end
|
#float ⇒ Object
88
89
90
|
# File 'lib/torch/nn/module.rb', line 88
# Casts floating-point tensors to float32; non-float tensors pass
# through untouched.
def float
  _apply(lambda { |tensor| tensor.floating_point? ? tensor.float : tensor })
end
|
#forward ⇒ Object
13
14
15
|
# File 'lib/torch/nn/module.rb', line 13
# Abstract: subclasses implement the module's computation here.
# #call delegates to this method.
def forward
raise NotImplementedError
end
|
#half ⇒ Object
96
97
98
|
# File 'lib/torch/nn/module.rb', line 96
# Casts floating-point tensors to float16; non-float tensors pass
# through untouched.
def half
  _apply(lambda { |tensor| tensor.floating_point? ? tensor.half : tensor })
end
|
#inspect ⇒ Object
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
|
# File 'lib/torch/nn/module.rb', line 244
def inspect
name = self.class.name.split("::").last
if named_children.empty?
"#{name}(#{})"
else
str = String.new
str << "#{name}(\n"
named_children.each do |name, mod|
mod_str = mod.inspect
mod_str = mod_str.lines.join(" ")
str << " (#{name}): #{mod_str}\n"
end
str << ")"
end
end
|
#load_state_dict(state_dict) ⇒ Object
TODO: add a strict option. TODO: match PyTorch behavior.
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
|
# File 'lib/torch/nn/module.rb', line 123
# Copies values from +state_dict+ (keys like "fc1.weight") into this
# module's registered parameters. Returns nil.
#
# Raises Error when a key names an unknown module or parameter.
def load_state_dict(state_dict)
  state_dict.each do |k, input_param|
    # Split "module.param" into the child module name and the
    # (possibly dotted) parameter name within that module.
    k1, k2 = k.split(".", 2)
    mod = named_modules[k1]
    if mod.is_a?(Module)
      param = mod.named_parameters[k2]
      if param.is_a?(Parameter)
        # Copy in place without recording the copy in autograd.
        Torch.no_grad do
          param.copy!(input_param)
        end
      else
        # Bug fix: report the full missing key, not the module
        # name (previously interpolated k1 here).
        raise Error, "Unknown parameter: #{k}"
      end
    else
      raise Error, "Unknown module: #{k1}"
    end
  end
  nil
end
|
#modules ⇒ Object
190
191
192
|
# File 'lib/torch/nn/module.rb', line 190
# Every module in the tree (self included), without names.
def modules
  named_modules.map { |_name, mod| mod }
end
|
#named_buffers ⇒ Object
170
171
172
|
# File 'lib/torch/nn/module.rb', line 170
# Name => tensor hash of registered buffers; returns {} when the
# @buffers registry has not been initialized yet.
def named_buffers
@buffers || {}
end
|
#named_children ⇒ Object
178
179
180
181
182
183
184
185
186
187
188
|
# File 'lib/torch/nn/module.rb', line 178
# Name => module map of direct children: any instance variable
# holding a Module (keyed by the ivar name minus "@"), plus every
# module registered via #add_module.
def named_children
  children_map = {}
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    children_map[ivar.to_s.delete_prefix("@")] = value if value.is_a?(Module)
  end
  @modules.each do |child_name, child|
    children_map[child_name] = child
  end
  children_map
end
|
#named_modules(memo: nil, prefix: "") ⇒ Object
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
|
# File 'lib/torch/nn/module.rb', line 195
# Dotted-name => module map of self and all descendants. +memo+
# guards against revisiting shared submodules; +prefix+ accumulates
# the dotted path during recursion.
def named_modules(memo: nil, prefix: "")
  result = {}
  memo ||= Set.new
  return result if memo.include?(self)
  memo << self
  result[prefix] = self
  named_children.each do |child_name, child|
    next unless child.is_a?(Module)
    child_prefix = prefix.empty? ? child_name : "#{prefix}.#{child_name}"
    result.update(child.named_modules(memo: memo, prefix: child_prefix))
  end
  result
end
|
#named_parameters(prefix: "", recurse: true) ⇒ Object
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
|
# File 'lib/torch/nn/module.rb', line 149
# Dotted-name => Parameter map: children's parameters first (when
# +recurse+), then any Parameter-valued instance variables, then the
# @parameters registry (nil entries skipped).
def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |child_name, child|
      params.merge!(child.named_parameters(prefix: "#{prefix}#{child_name}.", recurse: recurse))
    end
  end
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    params["#{prefix}#{ivar[1..-1]}"] = value if value.is_a?(Parameter)
  end
  @parameters.each do |param_name, param|
    params["#{prefix}#{param_name}"] = param if param
  end
  params
end
|
#parameters ⇒ Object
145
146
147
|
# File 'lib/torch/nn/module.rb', line 145
# All parameters of this module and its children, without names.
def parameters
  named_parameters.map { |_name, param| param }
end
|
#register_buffer(name, tensor) ⇒ Object
17
18
19
20
21
|
# File 'lib/torch/nn/module.rb', line 17
# Records +tensor+ in the buffer registry and also exposes it as an
# instance variable of the same name; returns +tensor+.
def register_buffer(name, tensor)
  @buffers[name] = tensor
  instance_variable_set(:"@#{name}", tensor)
end
|
#register_parameter(name, param) ⇒ Object
23
24
25
26
|
# File 'lib/torch/nn/module.rb', line 23
# Records +param+ in the parameter registry; returns +param+.
def register_parameter(name, param)
  @parameters.store(name, param)
end
|
#requires_grad!(requires_grad: true) ⇒ Object
224
225
226
227
228
229
|
# File 'lib/torch/nn/module.rb', line 224
# Sets requires_grad on every parameter in the tree; returns self.
def requires_grad!(requires_grad: true)
  parameters.each { |param| param.requires_grad!(requires_grad) }
  self
end
|
#respond_to?(method, include_private = false) ⇒ Boolean
273
274
275
276
|
# File 'lib/torch/nn/module.rb', line 273
# Mirrors #method_missing: the dynamic parameter/buffer/module
# readers report as callable. Short-circuits before falling back
# to the default lookup.
def respond_to?(method, include_private = false)
  key = method.to_s
  return true if named_parameters.key?(key) || named_buffers.key?(key) || named_modules.key?(key)
  super
end
|
#share_memory ⇒ Object
240
241
242
|
# File 'lib/torch/nn/module.rb', line 240
# Moves every parameter and buffer into shared memory.
def share_memory
  _apply(lambda { |tensor| tensor.share_memory! })
end
|
#state_dict(destination: nil) ⇒ Object
113
114
115
116
117
118
119
|
# File 'lib/torch/nn/module.rb', line 113
# Snapshot of the module's parameters as a name => tensor hash.
# Fills and returns +destination+ when given, otherwise a new hash.
def state_dict(destination: nil)
  destination ||= {}
  named_parameters.each_pair do |key, value|
    destination[key] = value
  end
  destination
end
|
#to(device) ⇒ Object
101
102
103
104
105
106
107
|
# File 'lib/torch/nn/module.rb', line 101
# Moves (or casts) every parameter and buffer via Tensor#to.
def to(device)
  _apply(->(tensor) { tensor.to(device) })
end
|
#train(mode = true) ⇒ Object
212
213
214
215
216
217
218
|
# File 'lib/torch/nn/module.rb', line 212
# Sets training mode on this module and, recursively, on every
# child; returns self.
def train(mode = true)
  @training = mode
  children.each { |child| child.train(mode) }
  self
end
|
#type(dst_type) ⇒ Object
84
85
86
|
# File 'lib/torch/nn/module.rb', line 84
# Casts every parameter and buffer to +dst_type+.
def type(dst_type)
  _apply(lambda { |tensor| tensor.type(dst_type) })
end
|
#zero_grad ⇒ Object
231
232
233
234
235
236
237
238
|
# File 'lib/torch/nn/module.rb', line 231
# Detaches and zeroes the gradient of every parameter that has one;
# parameters without gradients are skipped.
def zero_grad
  parameters.each do |param|
    grad = param.grad
    next unless grad
    grad.detach!
    grad.zero!
  end
end
|