Class: DNN::Losses::Loss
- Inherits: Object
show all
- Defined in:
- lib/dnn/core/losses.rb
Class Method Summary
collapse
Instance Method Summary
collapse
Class Method Details
.call(y, t, *args, **kwargs) ⇒ Object
5
6
7
|
# File 'lib/dnn/core/losses.rb', line 5
# Shortcut for one-off evaluation: builds an instance of the receiving class
# with the given constructor arguments and immediately evaluates it on +y+ and +t+.
#
# @param y prediction
# @param t target
# @param args positional arguments forwarded to +new+
# @param kwargs keyword arguments forwarded to +new+
# @return whatever the instance's #call returns
def self.call(y, t, *args, **kwargs)
  instance = new(*args, **kwargs)
  instance.call(y, t)
end
|
.from_hash(hash) ⇒ Object
9
10
11
12
13
14
15
16
|
# File 'lib/dnn/core/losses.rb', line 9
# Rebuilds a loss object from a hash such as the one produced by #to_hash.
#
# @param hash [Hash, nil] serialized loss; +hash[:class]+ holds the class name
#   to instantiate. +nil+ is passed straight through.
# @return [Loss, nil] the restored loss, or +nil+ when +hash+ is +nil+.
# @raise [DNNError] if the named constant is not a class derived from this class.
#
# NOTE(review): +hash[:class]+ is resolved with +const_get+, so any constant
# reachable from DNN (including top-level constants) can be named. Only load
# hashes that come from a trusted source.
def self.from_hash(hash)
  return nil unless hash
  loss_class = DNN.const_get(hash[:class])
  # Validate before allocating: calling +allocate+ on a module, or on a class
  # without an allocator (e.g. Integer), would raise NoMethodError/TypeError
  # instead of the intended DNNError.
  unless loss_class.is_a?(Class) && loss_class <= self
    raise DNNError, "#{loss_class} is not an instance of #{self} class."
  end
  loss = loss_class.allocate
  loss.load_hash(hash)
  loss
end
|
Instance Method Details
#call(y, t) ⇒ Object
18
19
20
|
# File 'lib/dnn/core/losses.rb', line 18
# Evaluates this loss on prediction +y+ and target +t+.
# Lets an instance be invoked as a callable (e.g. +loss.(y, t)+); all of the
# work is delegated to #forward.
#
# @return whatever #forward returns
def call(y, t)
  forward(y, t)
end
|
#clean ⇒ Object
55
56
57
58
59
60
61
|
# File 'lib/dnn/core/losses.rb', line 55
# Resets the object to a freshly loaded state.
# It snapshots the serializable configuration with #to_hash, sets every
# instance variable to nil, then restores the configuration with #load_hash.
# Any state not captured by #to_hash is therefore discarded.
#
# @return whatever #load_hash returns
def clean
  snapshot = to_hash
  instance_variables.each { |name| instance_variable_set(name, nil) }
  load_hash(snapshot)
end
|
#forward(y, t) ⇒ Object
32
33
34
|
# File 'lib/dnn/core/losses.rb', line 32
# Computes the loss value for prediction +y+ against target +t+.
# This is abstract in the base class: concrete losses must override it.
#
# @param y prediction
# @param t target
# @raise [NotImplementedError] always, in this base implementation.
def forward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'forward'"
end
|
#load_hash(hash) ⇒ Object
51
52
53
|
# File 'lib/dnn/core/losses.rb', line 51
# Restores this loss from a serialized hash (used by .from_hash after +allocate+).
# The base class has no configurable state, so +hash+ is not read; the object
# is simply run through #initialize with no arguments.
#
# @param hash [Hash] serialized form (ignored here)
# @return whatever #initialize returns
def load_hash(hash)
  initialize
end
|
#loss(y, t, layers: nil, loss_weight: nil) ⇒ Object
22
23
24
25
26
27
28
29
30
|
# File 'lib/dnn/core/losses.rb', line 22
# Computes the training loss for prediction +y+ against target +t+.
# The raw value from #call is optionally scaled by +loss_weight+. When
# +layers+ is given, the result is then passed through #regularizers_forward.
#
# @param y prediction; its +shape+ must equal +t.shape+
# @param t target
# @param layers [Enumerable, nil] layers whose regularizers are applied, if any
# @param loss_weight optional multiplier, applied before regularization
# @return the (weighted, regularized) loss value
# @raise [DNNShapeError] if +y.shape+ differs from +t.shape+
def loss(y, t, layers: nil, loss_weight: nil)
  unless y.shape == t.shape
    raise DNNShapeError, "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}."
  end
  value = call(y, t)
  value *= loss_weight if loss_weight
  layers ? regularizers_forward(value, layers) : value
end
|
#regularizers_forward(loss, layers) ⇒ Object
36
37
38
39
40
41
42
43
|
# File 'lib/dnn/core/losses.rb', line 36
# Threads +loss+ through the +forward+ of every regularizer attached to +layers+.
# Layers that do not respond to +regularizers+ are skipped. The collected
# regularizers are flattened and applied in order, each one receiving the
# previous result.
#
# @param loss initial loss value
# @param layers [Enumerable] candidate layers
# @return the loss after all regularizers (unchanged if there are none)
def regularizers_forward(loss, layers)
  applicable = layers.select { |layer| layer.respond_to?(:regularizers) }
  all_regularizers = applicable.map(&:regularizers).flatten
  all_regularizers.inject(loss) { |acc, regularizer| regularizer.forward(acc) }
end
|
#to_hash(merge_hash = nil) ⇒ Object
45
46
47
48
49
|
# File 'lib/dnn/core/losses.rb', line 45
# Serializes this loss to a hash.
# The +:class+ entry records the concrete class name (read back by .from_hash).
# If +merge_hash+ is given, its entries are added and override keys already present.
#
# @param merge_hash [Hash, nil] extra entries to include
# @return [Hash]
def to_hash(merge_hash = nil)
  base = { class: self.class.name }
  merge_hash ? base.merge(merge_hash) : base
end
|