Class: Torch::NN::Module
- Inherits:
-
Object
show all
- Defined in:
- lib/torch/nn/module.rb
Direct Known Subclasses
AvgPoolNd, Bilinear, ConvNd, CosineSimilarity, DropoutNd, Embedding, EmbeddingBag, Identity, LeakyReLU, Linear, LogSoftmax, Loss, MaxPoolNd, PReLU, PairwiseDistance, RNNBase, ReLU, Sequential, Sigmoid, Softmax, Softmax2d, Softmin, Softplus
Instance Method Summary
collapse
Constructor Details
#initialize ⇒ Module
Returns a new instance of Module.
4
5
6
7
8
9
|
# File 'lib/torch/nn/module.rb', line 4
# Sets up the module's empty registries (parameters, buffers, child
# modules) and starts it in training mode.
def initialize
  @parameters = {}
  @buffers = {}
  @modules = {}
  @training = true
end
|
Dynamic Method Handling
This class handles dynamic methods through the method_missing method
#method_missing(method, *args, &block) ⇒ Object
186
187
188
189
190
191
192
193
194
195
196
197
|
# File 'lib/torch/nn/module.rb', line 186
# Resolves attribute-style access (e.g. +model.weight+, +model.fc1+) to
# registered parameters, buffers, and submodules, in that priority
# order. Unregistered names fall through to the default behavior.
def method_missing(method, *args, &block)
  key = method.to_s
  [named_parameters, named_buffers, named_modules].each do |registry|
    return registry[key] if registry.key?(key)
  end
  super
end
|
Instance Method Details
#_apply(fn) ⇒ Object
30
31
32
33
34
35
36
|
# File 'lib/torch/nn/module.rb', line 30
# Recursively walks the module tree, calling _apply on every child, and
# returns +self+.
#
# NOTE(review): as written, +fn+ is only threaded down the recursion and
# is never invoked on this module's own parameters or buffers, so the
# _apply-based helpers (#cpu, #cuda, #float, #to, ...) appear not to
# transform any tensors at the leaf level — verify against how
# Parameter/tensor state is handled elsewhere in the library.
def _apply(fn)
  children.each do |mod|
    mod._apply(fn)
  end
  self
end
|
#add_module(name, mod) ⇒ Object
25
26
27
28
|
# File 'lib/torch/nn/module.rb', line 25
# Registers +mod+ as a child module under +name+; returns +mod+.
def add_module(name, mod)
  @modules.store(name, mod)
end
|
#apply(fn) ⇒ Object
38
39
40
41
42
43
44
|
# File 'lib/torch/nn/module.rb', line 38
# Applies +fn+ to every submodule (depth-first) and finally to +self+.
# Returns +self+ for chaining.
def apply(fn)
  children.each { |child| child.apply(fn) }
  fn.(self)
  self
end
|
#buffers ⇒ Object
108
109
110
|
# File 'lib/torch/nn/module.rb', line 108
# All registered buffers, without their names.
def buffers
  named_buffers.map { |_name, buffer| buffer }
end
|
#call(*input) ⇒ Object
79
80
81
|
# File 'lib/torch/nn/module.rb', line 79
# Invokes the module like a function, delegating to #forward.
def call(*args)
  forward(*args)
end
|
#children ⇒ Object
116
117
118
|
# File 'lib/torch/nn/module.rb', line 116
# Immediate child modules, without their names.
def children
  named_children.map { |_name, mod| mod }
end
|
#cpu ⇒ Object
50
51
52
|
# File 'lib/torch/nn/module.rb', line 50
# Moves every tensor in the module tree to the CPU.
def cpu
  _apply(lambda { |t| t.cpu })
end
|
#cuda(device: nil) ⇒ Object
46
47
48
|
# File 'lib/torch/nn/module.rb', line 46
# Moves every tensor in the module tree to a CUDA device
# (the default device when +device+ is nil).
def cuda(device: nil)
  _apply(lambda { |t| t.cuda(device) })
end
|
#double ⇒ Object
62
63
64
|
# File 'lib/torch/nn/module.rb', line 62
# Casts floating-point tensors to double precision; non-floating
# tensors are left untouched.
def double
  _apply(lambda { |t| t.floating_point? ? t.double : t })
end
|
#eval ⇒ Object
148
149
150
|
# File 'lib/torch/nn/module.rb', line 148
# Switches this module and all children into evaluation mode
# (equivalent to +train(false)+).
def eval
  train(false)
end
|
#float ⇒ Object
58
59
60
|
# File 'lib/torch/nn/module.rb', line 58
# Casts floating-point tensors to single precision; non-floating
# tensors are left untouched.
def float
  _apply(lambda { |t| t.floating_point? ? t.float : t })
end
|
#forward ⇒ Object
11
12
13
|
# File 'lib/torch/nn/module.rb', line 11
# The module's computation. Subclasses must override this; the base
# implementation always raises.
#
# @raise [NotImplementedError]
def forward
  raise NotImplementedError
end
|
#half ⇒ Object
66
67
68
|
# File 'lib/torch/nn/module.rb', line 66
# Casts floating-point tensors to half precision; non-floating tensors
# are left untouched.
def half
  _apply(lambda { |t| t.floating_point? ? t.half : t })
end
|
#inspect ⇒ Object
172
173
174
175
176
177
178
179
180
181
182
183
184
|
# File 'lib/torch/nn/module.rb', line 172
# Human-readable summary of the module: leaf modules render as
# "Name(extra)"; container modules list their children one per line.
#
# Fix: the original iterated +children+ (an Array of modules) with a
# two-variable block, so +name+ received the module itself and +mod+
# was always nil, producing lines like "(#<Module...>): nil". Iterate
# +named_children+ (a name => module Hash) so the pair destructures
# correctly; the block variable is also renamed so it no longer shadows
# the outer +name+.
def inspect
  name = self.class.name.split("::").last
  if children.empty?
    "#{name}(#{extra_inspect})"
  else
    str = String.new
    str << "#{name}(\n"
    named_children.each do |child_name, mod|
      str << "  (#{child_name}): #{mod.inspect}\n"
    end
    str << ")"
  end
end
|
#modules ⇒ Object
132
133
134
|
# File 'lib/torch/nn/module.rb', line 132
# All modules in the tree (self included), without their names.
def modules
  named_modules.map { |_name, mod| mod }
end
|
#named_buffers ⇒ Object
112
113
114
|
# File 'lib/torch/nn/module.rb', line 112
# Hash of buffer name => tensor. Returns an empty hash when the buffer
# registry has not been populated yet.
def named_buffers
  @buffers.nil? ? {} : @buffers
end
|
#named_children ⇒ Object
120
121
122
123
124
125
126
127
128
129
130
|
# File 'lib/torch/nn/module.rb', line 120
# Hash of child name => module. Collects modules assigned directly to
# instance variables (keyed by the ivar name without "@") and then
# merges in modules registered via #add_module, which win on name
# clashes. The local is named +result+ so it does not shadow #modules.
def named_children
  result = {}
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    result[ivar[1..-1]] = value if value.is_a?(Module)
  end
  result.merge!(@modules)
  result
end
|
#named_modules ⇒ Object
136
137
138
|
# File 'lib/torch/nn/module.rb', line 136
# Hash of name => module for the whole tree visible from here, with
# +self+ under the empty-string key (listed first).
def named_modules
  result = {"" => self}
  result.merge!(named_children)
  result
end
|
#named_parameters(prefix: "", recurse: true) ⇒ Object
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
|
# File 'lib/torch/nn/module.rb', line 91
# Hash of qualified parameter name => Parameter.
#
# @param prefix [String] prepended to every returned name
# @param recurse [Boolean] when true, also include parameters of all
#   descendant modules under dotted names (e.g. "fc1.weight")
#
# Fix: the recursive call previously passed +prefix: "#{name}."+,
# discarding the caller-supplied prefix, so parameters nested two or
# more levels deep lost their outer path segments (e.g. "fc.weight"
# instead of "block.fc.weight"). The child prefix is now built on top
# of the incoming one.
def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  # Parameters assigned straight to instance variables, keyed by the
  # ivar name without the leading "@".
  instance_variables.each do |name|
    param = instance_variable_get(name)
    params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
  end
  # Explicitly registered parameters win on name clashes.
  @parameters.each do |name, param|
    params[[prefix, name].join] = param
  end
  params
end
|
#parameters ⇒ Object
87
88
89
|
# File 'lib/torch/nn/module.rb', line 87
# All parameters in the module tree, without their names.
def parameters
  named_parameters.map { |_name, param| param }
end
|
#register_buffer(name, tensor) ⇒ Object
15
16
17
18
|
# File 'lib/torch/nn/module.rb', line 15
# Registers +tensor+ as a buffer (non-learnable state, e.g. a running
# mean) under +name+; returns +tensor+.
def register_buffer(name, tensor)
  @buffers.store(name, tensor)
end
|
#register_parameter(name, param) ⇒ Object
20
21
22
23
|
# File 'lib/torch/nn/module.rb', line 20
# Registers +param+ as a learnable parameter under +name+;
# returns +param+.
def register_parameter(name, param)
  @parameters.store(name, param)
end
|
#requires_grad!(requires_grad: true) ⇒ Object
152
153
154
155
156
157
|
# File 'lib/torch/nn/module.rb', line 152
# Sets the requires_grad flag on every parameter in the tree.
# Returns +self+ for chaining.
def requires_grad!(requires_grad: true)
  parameters.each { |param| param.requires_grad!(requires_grad) }
  self
end
|
#respond_to?(method, include_private = false) ⇒ Boolean
199
200
201
202
|
# File 'lib/torch/nn/module.rb', line 199
# Reports that the names resolved dynamically by #method_missing
# (registered parameters, buffers, and submodules) are callable.
#
# Fix: override +respond_to_missing?+ rather than +respond_to?+ — the
# idiomatic companion to method_missing. +respond_to?+ continues to
# return the same answers via its default implementation, and
# Object#method can now locate the dynamic methods too.
def respond_to_missing?(method, include_private = false)
  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end
|
#share_memory ⇒ Object
168
169
170
|
# File 'lib/torch/nn/module.rb', line 168
# Moves every tensor in the module tree into shared memory (in place).
def share_memory
  _apply(lambda { |t| t.share_memory! })
end
|
#state_dict ⇒ Object
83
84
85
|
# File 'lib/torch/nn/module.rb', line 83
# Serializes the module's state. Not implemented yet.
#
# @raise [NotImplementedYet] always — presumably a project-level error
#   class defined elsewhere in the library (not visible in this file)
def state_dict
  raise NotImplementedYet
end
|
#to(device) ⇒ Object
71
72
73
74
75
76
77
|
# File 'lib/torch/nn/module.rb', line 71
# Moves (or casts) every tensor in the module tree to +device+.
# Returns +self+ (via _apply).
def to(device)
  _apply ->(t) { t.to(device) }
end
|
#train(mode = true) ⇒ Object
140
141
142
143
144
145
146
|
# File 'lib/torch/nn/module.rb', line 140
# Puts this module and, recursively, all of its children into training
# mode (+mode+ = true) or evaluation mode (+mode+ = false).
# Returns +self+ for chaining.
def train(mode = true)
  @training = mode
  children.each { |child| child.train(mode) }
  self
end
|
#type(dst_type) ⇒ Object
54
55
56
|
# File 'lib/torch/nn/module.rb', line 54
# Casts every tensor in the module tree to +dst_type+.
def type(dst_type)
  _apply(lambda { |t| t.type(dst_type) })
end
|
#zero_grad ⇒ Object
159
160
161
162
163
164
165
166
|
# File 'lib/torch/nn/module.rb', line 159
# Clears the gradient of every parameter that has one: the grad tensor
# is detached from its computation graph and zeroed in place.
def zero_grad
  parameters.each do |param|
    next unless param.grad
    param.grad.detach!
    param.grad.zero!
  end
end
|