Class: Torch::NN::Module
- Inherits: Object
- Includes: Utils
- Defined in: lib/torch/nn/module.rb
Direct Known Subclasses
AdaptiveAvgPoolNd, AdaptiveMaxPoolNd, AvgPoolNd, BatchNorm, Bilinear, ConstantPadNd, ConvNd, CosineSimilarity, DropoutNd, ELU, Embedding, EmbeddingBag, Fold, GELU, GroupNorm, Hardshrink, Identity, LPPoolNd, LayerNorm, LeakyReLU, Linear, LocalResponseNorm, LogSigmoid, LogSoftmax, Loss, MaxPoolNd, MaxUnpoolNd, ModuleList, MultiheadAttention, PReLU, PairwiseDistance, ParameterList, RNNBase, ReLU, ReflectionPadNd, ReplicationPadNd, Sequential, Sigmoid, Softmax, Softmax2d, Softmin, Softplus, Softshrink, Softsign, Tanh, Tanhshrink, Transformer, TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer, Unfold, Upsample
Instance Attribute Summary
- #training ⇒ Object
Instance Method Summary
- #_apply(fn) ⇒ Object
- #add_module(name, mod) ⇒ Object
- #apply(fn) ⇒ Object
- #buffers ⇒ Object
- #call(*input, **kwargs) ⇒ Object
- #children ⇒ Object
- #cpu ⇒ Object
- #cuda ⇒ Object
- #deep_dup ⇒ Object
- #double ⇒ Object
- #eval ⇒ Object
- #float ⇒ Object
- #forward ⇒ Object
- #half ⇒ Object
- #initialize ⇒ Module (constructor): A new instance of Module.
- #inspect ⇒ Object
- #load_state_dict(state_dict, strict: true) ⇒ Object
- #method_missing(method, *args, &block) ⇒ Object
- #modules ⇒ Object
- #named_buffers(prefix: "", recurse: false) ⇒ Object (TODO: set recurse: true in 0.18.0)
- #named_children ⇒ Object
- #named_modules(memo: nil, prefix: "") ⇒ Object
- #named_parameters(prefix: "", recurse: true) ⇒ Object
- #parameters ⇒ Object
- #register_buffer(name, tensor, persistent: true) ⇒ Object
- #register_parameter(name, param) ⇒ Object
- #requires_grad!(requires_grad: true) ⇒ Object
- #respond_to?(method, include_private = false) ⇒ Boolean
- #share_memory ⇒ Object
- #state_dict(destination: nil, prefix: "") ⇒ Object
- #to(device) ⇒ Object
- #train(mode = true) ⇒ Object
- #type(dst_type) ⇒ Object
- #zero_grad ⇒ Object
Methods included from Utils
#_activation_fn, #_clones, #_ntuple, #_pair, #_quadrupal, #_single, #_triple
Constructor Details
#initialize ⇒ Module
Returns a new instance of Module.
# File 'lib/torch/nn/module.rb', line 8

def initialize
  @training = true
  @parameters = {}
  @buffers = {}
  @modules = {}
  @non_persistent_buffers_set = Set.new
end
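For illustration, a minimal sketch of subclassing Module: call super() before assigning submodules to instance variables, then define forward with the computation. TinyNet and the layer sizes are hypothetical; Torch::NN::F.relu is assumed to be available, as used in the torch.rb README.

class TinyNet < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(4, 8)  # instance variables holding modules become named children
    @fc2 = Torch::NN::Linear.new(8, 2)
  end

  def forward(x)
    x = Torch::NN::F.relu(@fc1.call(x))
    @fc2.call(x)
  end
end

net = TinyNet.new
net.call(Torch.randn(1, 4))  # #call dispatches to #forward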
Dynamic Method Handling
This class handles dynamic methods through the method_missing method.
#method_missing(method, *args, &block) ⇒ Object
# File 'lib/torch/nn/module.rb', line 305

def method_missing(method, *args, &block)
  name = method.to_s
  if named_parameters.key?(name)
    named_parameters[name]
  elsif named_buffers.key?(name)
    named_buffers[name]
  elsif named_modules.key?(name)
    named_modules[name]
  elsif method.end_with?("=") && named_modules.key?(method[0..-2])
    if instance_variable_defined?("@#{method[0..-2]}")
      instance_variable_set("@#{method[0..-2]}", *args)
    else
      raise NotImplementedYet
    end
  else
    super
  end
end
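A sketch of the dynamic lookup this enables: parameters, buffers, and submodules registered on a module become readable by name. The Wrapper class below is hypothetical.

class Wrapper < Torch::NN::Module
  def initialize
    super()
    add_module("encoder", Torch::NN::Linear.new(8, 4))
    register_buffer("scale", Torch.ones(1))
  end
end

w = Wrapper.new
w.encoder         # resolved through named_modules
w.scale           # resolved through named_buffers
w.encoder.weight  # resolved through named_parameters on the child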
Instance Attribute Details
#training ⇒ Object
Returns the value of attribute training.
# File 'lib/torch/nn/module.rb', line 6

def training
  @training
end
Instance Method Details
#_apply(fn) ⇒ Object
# File 'lib/torch/nn/module.rb', line 42

def _apply(fn)
  children.each do |mod|
    mod._apply(fn)
  end

  instance_variables.each do |key|
    param = instance_variable_get(key)
    if param.is_a?(Parameter)
      param_applied = nil
      Torch.no_grad do
        param_applied = fn.call(param)
      end
      instance_variable_set(key, Parameter.new(param_applied, requires_grad: param.requires_grad))

      if param.grad
        grad_applied = nil
        Torch.no_grad do
          grad_applied = fn.call(param.grad)
        end
        instance_variable_get(key).grad = grad_applied.requires_grad!(param.grad.requires_grad)
      end
    end
  end

  @buffers.each_key do |k|
    buf = @buffers[k]
    unless buf.nil?
      @buffers[k] = fn.call(buf)
      instance_variable_set("@#{k}", @buffers[k])
    end
  end

  self
end
#add_module(name, mod) ⇒ Object
# File 'lib/torch/nn/module.rb', line 37

def add_module(name, mod)
  @modules[name] = mod
end
#apply(fn) ⇒ Object
# File 'lib/torch/nn/module.rb', line 79

def apply(fn)
  children.each do |mod|
    mod.apply(fn)
  end
  fn.call(self)
  self
end
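#apply is commonly used for weight initialization, since the block is called on every submodule and then on the module itself. A sketch, assuming Torch::NN::Init.xavier_uniform! mirrors PyTorch's nn.init.xavier_uniform_:

init_weights = lambda do |m|
  # initialize only Linear layers; other modules pass through untouched
  Torch::NN::Init.xavier_uniform!(m.weight) if m.is_a?(Torch::NN::Linear)
end

net = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(4, 8),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(8, 2)
)
net.apply(init_weights)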
#buffers ⇒ Object
# File 'lib/torch/nn/module.rb', line 196

def buffers
  named_buffers.values
end
#call(*input, **kwargs) ⇒ Object
# File 'lib/torch/nn/module.rb', line 121

def call(*input, **kwargs)
  forward(*input, **kwargs)
end
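#call simply forwards its arguments to #forward, so a module is invoked like a callable. A small sketch with arbitrary sizes:

layer = Torch::NN::Linear.new(4, 2)
out = layer.call(Torch.randn(3, 4))
out.shape  # => [3, 2]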
#children ⇒ Object
# File 'lib/torch/nn/module.rb', line 214

def children
  named_children.values
end
#cpu ⇒ Object
# File 'lib/torch/nn/module.rb', line 92

def cpu
  _apply ->(t) { t.cpu }
end
#cuda ⇒ Object
# File 'lib/torch/nn/module.rb', line 88

def cuda
  _apply ->(t) { t.cuda }
end
#deep_dup ⇒ Object
# File 'lib/torch/nn/module.rb', line 300

def deep_dup
  memo = {}
  dup_value(self, memo)
end
#double ⇒ Object
# File 'lib/torch/nn/module.rb', line 104

def double
  _apply ->(t) { t.floating_point? ? t.double : t }
end
#eval ⇒ Object
# File 'lib/torch/nn/module.rb', line 260

def eval
  train(false)
end
#float ⇒ Object
# File 'lib/torch/nn/module.rb', line 100

def float
  _apply ->(t) { t.floating_point? ? t.float : t }
end
#forward ⇒ Object
# File 'lib/torch/nn/module.rb', line 16

def forward
  raise NotImplementedError
end
#half ⇒ Object
# File 'lib/torch/nn/module.rb', line 108

def half
  _apply ->(t) { t.floating_point? ? t.half : t }
end
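#float, #double, and #half all route through #_apply and only convert floating-point tensors, so integer buffers are left untouched. A sketch:

net = Torch::NN::Linear.new(4, 2)
net.double  # parameters become float64
net.half    # parameters become float16
net.float   # back to float32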
#inspect ⇒ Object
# File 'lib/torch/nn/module.rb', line 284

def inspect
  name = self.class.name.split("::").last
  if named_children.empty?
    "#{name}(#{extra_inspect})"
  else
    str = String.new
    str << "#{name}(\n"
    named_children.each do |name, mod|
      mod_str = mod.inspect
      mod_str = mod_str.lines.join("  ")
      str << "  (#{name}): #{mod_str}\n"
    end
    str << ")"
  end
end
#load_state_dict(state_dict, strict: true) ⇒ Object
# File 'lib/torch/nn/module.rb', line 136

def load_state_dict(state_dict, strict: true)
  raise "strict: false not implemented yet" unless strict

  missing_keys = []
  unexpected_keys = []
  error_msgs = []

  _load = lambda do |mod, prefix = ""|
    local_metadata = {}
    mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
    mod.named_children.each do |name, child|
      _load.call(child, prefix + name + ".") unless child.nil?
    end
  end

  _load.call(self)

  if strict
    if unexpected_keys.any?
      error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
    end

    if missing_keys.any?
      error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
    end
  end

  if error_msgs.any?
    raise Error, error_msgs[0]
  end

  nil
end
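Together with #state_dict, this supports the usual save/load round trip. A sketch, reusing the hypothetical TinyNet from the constructor example and an arbitrary file name:

net = TinyNet.new
Torch.save(net.state_dict, "tiny_net.pth")

restored = TinyNet.new
restored.load_state_dict(Torch.load("tiny_net.pth"))
restored.eval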
#modules ⇒ Object
# File 'lib/torch/nn/module.rb', line 230

def modules
  named_modules.values
end
#named_buffers(prefix: "", recurse: false) ⇒ Object
TODO set recurse: true in 0.18.0
# File 'lib/torch/nn/module.rb', line 201

def named_buffers(prefix: "", recurse: false)
  buffers = {}
  if recurse
    named_children.each do |name, mod|
      buffers.merge!(mod.named_buffers(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  (@buffers || {}).each do |k, v|
    buffers[[prefix, k].join] = v
  end
  buffers
end
#named_children ⇒ Object
# File 'lib/torch/nn/module.rb', line 218

def named_children
  modules = {}
  instance_variables.each do |name|
    mod = instance_variable_get(name)
    modules[name[1..-1]] = mod if mod.is_a?(Module)
  end
  @modules.each do |name, mod|
    modules[name] = mod
  end
  modules
end
#named_modules(memo: nil, prefix: "") ⇒ Object
# File 'lib/torch/nn/module.rb', line 235

def named_modules(memo: nil, prefix: "")
  ret = {}
  memo ||= Set.new
  unless memo.include?(self)
    memo << self
    ret[prefix] = self
    named_children.each do |name, mod|
      next unless mod.is_a?(Module)
      submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
      mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
        ret[m[0]] = m[1]
      end
    end
  end
  ret
end
#named_parameters(prefix: "", recurse: true) ⇒ Object
# File 'lib/torch/nn/module.rb', line 179

def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  instance_variables.each do |name|
    param = instance_variable_get(name)
    params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
  end
  @parameters.each do |name, param|
    params[[prefix, name].join] = param if param
  end
  params
end
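Recursion prefixes each key with the child's name, so nested parameters get dotted names. A sketch using the hypothetical TinyNet from the constructor example:

TinyNet.new.named_parameters.keys
# => ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]

TinyNet.new.named_parameters(recurse: false)
# => {} (no parameters registered directly on TinyNet itself)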
#parameters ⇒ Object
# File 'lib/torch/nn/module.rb', line 175

def parameters
  named_parameters.values
end
#register_buffer(name, tensor, persistent: true) ⇒ Object
# File 'lib/torch/nn/module.rb', line 20

def register_buffer(name, tensor, persistent: true)
  @buffers[name] = tensor
  instance_variable_set("@#{name}", tensor)

  if persistent
    @non_persistent_buffers_set.delete(name)
  else
    @non_persistent_buffers_set << name
  end
end
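Buffers hold module state that is not a parameter (running statistics, for example); persistent: false records the name in the non-persistent set so it can be excluded from serialization. A sketch (the RunningStats class is hypothetical):

class RunningStats < Torch::NN::Module
  def initialize
    super()
    register_buffer("running_mean", Torch.zeros(10))
    register_buffer("scratch", Torch.zeros(10), persistent: false)
  end
end

m = RunningStats.new
m.named_buffers.keys  # => ["running_mean", "scratch"]
m.running_mean        # buffers are readable by name via method_missing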
#register_parameter(name, param) ⇒ Object
# File 'lib/torch/nn/module.rb', line 32

def register_parameter(name, param)
  @parameters[name] = param
end
#requires_grad!(requires_grad: true) ⇒ Object
# File 'lib/torch/nn/module.rb', line 264

def requires_grad!(requires_grad: true)
  parameters.each do |p|
    p.requires_grad!(requires_grad)
  end
  self
end
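A typical use is freezing a model or submodule for feature extraction or fine-tuning. A sketch:

backbone = Torch::NN::Linear.new(128, 64)
backbone.requires_grad!(requires_grad: false)     # freeze every parameter
backbone.parameters.all? { |p| p.requires_grad }  # => false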
#respond_to?(method, include_private = false) ⇒ Boolean
# File 'lib/torch/nn/module.rb', line 324

def respond_to?(method, include_private = false)
  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end
#share_memory ⇒ Object
# File 'lib/torch/nn/module.rb', line 280

def share_memory
  _apply ->(t) { t.share_memory! }
end
#state_dict(destination: nil, prefix: "") ⇒ Object
# File 'lib/torch/nn/module.rb', line 125

def state_dict(destination: nil, prefix: "")
  destination ||= {}
  save_to_state_dict(destination, prefix: prefix)

  named_children.each do |name, mod|
    next unless mod
    mod.state_dict(destination: destination, prefix: prefix + name + ".")
  end
  destination
end
#to(device) ⇒ Object
# File 'lib/torch/nn/module.rb', line 113

def to(device)
  convert = lambda do |t|
    t.to(device)
  end

  _apply(convert)
end
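Since #to goes through #_apply, parameters, gradients, and buffers are all converted. A sketch that picks a device at runtime; input tensors still have to be moved separately:

device = Torch::CUDA.available? ? "cuda" : "cpu"

net = Torch::NN::Linear.new(10, 2).to(device)
x = Torch.randn(1, 10).to(device)
net.call(x)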
#train(mode = true) ⇒ Object
# File 'lib/torch/nn/module.rb', line 252

def train(mode = true)
  @training = mode

  children.each do |mod|
    mod.train(mode)
  end
  self
end
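Training mode propagates to all children and is what modules such as Dropout and BatchNorm check at run time. A sketch:

drop = Torch::NN::Dropout.new(p: 0.5)
drop.training  # => true (modules start in training mode)

drop.eval      # same as train(false); dropout becomes a no-op
drop.training  # => false

drop.train     # back to training mode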
#type(dst_type) ⇒ Object
# File 'lib/torch/nn/module.rb', line 96

def type(dst_type)
  _apply ->(t) { t.type(dst_type) }
end
#zero_grad ⇒ Object
# File 'lib/torch/nn/module.rb', line 271

def zero_grad
  parameters.each do |param|
    if param.grad
      param.grad.detach!
      param.grad.zero!
    end
  end
end
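Gradients accumulate across backward passes, so #zero_grad (or the optimizer's equivalent) is normally called once per iteration. A sketch of a manual training step with arbitrary data and learning rate:

net = Torch::NN::Linear.new(10, 1)
x = Torch.randn(8, 10)
y = Torch.randn(8, 1)
loss_fn = Torch::NN::MSELoss.new

3.times do
  net.zero_grad
  loss = loss_fn.call(net.call(x), y)
  loss.backward
  Torch.no_grad do
    net.parameters.each { |p| p.sub!(p.grad * 0.1) }  # plain SGD update
  end
end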