Class: DNN::TwoInputLink

Inherits: Object
Defined in:
lib/dnn/core/link.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(prev1 = nil, prev2 = nil, layer_node = nil) ⇒ TwoInputLink



# File 'lib/dnn/core/link.rb', lines 30–36

# Builds a link that joins two upstream inputs into a single layer node.
#
# @param prev1 [Object, nil] upstream link that receives the first gradient
#   in #backward (nil is tolerated there via safe navigation).
# @param prev2 [Object, nil] upstream link that receives the second gradient
#   in #backward (nil is tolerated there via safe navigation).
# @param layer_node [Object, nil] node called with both buffered inputs in
#   #forward and asked for gradients in #backward.
def initialize(prev1 = nil, prev2 = nil, layer_node = nil)
  @prev1, @prev2 = prev1, prev2
  @layer_node = layer_node
  # Downstream link; starts unset. NOTE(review): presumably wired up by the
  # code that builds the graph — not assigned anywhere in this view.
  @next = nil
  # Inputs received so far; #forward fires once two have arrived.
  @hold = []
end

Instance Attribute Details

#layer_node ⇒ Object

Returns the value of attribute layer_node.



# File 'lib/dnn/core/link.rb', lines 28–30

# The node this link wraps: #forward calls it with both buffered inputs and
# #backward asks it (via +backward_node+) for the per-input gradients.
def layer_node; @layer_node; end

#next ⇒ Object

Returns the value of attribute next.



# File 'lib/dnn/core/link.rb', lines 27–29

# Downstream link that receives this link's output in #forward; when it is
# nil, #forward returns the layer node's output directly.
def next; @next; end

#prev1 ⇒ Object

Returns the value of attribute prev1.



# File 'lib/dnn/core/link.rb', lines 25–27

# Upstream link that receives the first gradient in #backward (may be nil).
def prev1; @prev1; end

#prev2 ⇒ Object

Returns the value of attribute prev2.



# File 'lib/dnn/core/link.rb', lines 26–28

# Upstream link that receives the second gradient in #backward, when the
# layer node produces one (may be nil).
def prev2; @prev2; end

Instance Method Details

#backward(dy = Xumo::SFloat[1]) ⇒ Object



# File 'lib/dnn/core/link.rb', lines 46–55

# Sends gradient +dy+ back through the layer node and on to the upstream links.
#
# The layer node's +backward_node+ may return either an Array (its first two
# elements go to prev1 and prev2) or a single gradient (which goes to prev1
# only).
#
# @param dy [Object] incoming gradient; defaults to Xumo::SFloat[1].
# @return [Object, nil] prev2's +backward+ result when a second gradient
#   exists and prev2 is set; otherwise nil.
def backward(dy = Xumo::SFloat[1])
  grads = @layer_node.backward_node(dy)
  first_grad, second_grad = grads.is_a?(Array) ? grads : [grads, nil]
  @prev1&.backward(first_grad)
  # Skip prev2 entirely when the layer produced no (or a falsy) second gradient.
  return unless second_grad
  @prev2&.backward(second_grad)
end

#forward(x) ⇒ Object



# File 'lib/dnn/core/link.rb', lines 38–44

# Buffers one input; once two inputs have arrived, runs the layer node on
# both (in arrival order) and passes the result downstream.
#
# @param x [Object] one of the two inputs to the layer node.
# @return [Object, nil] nil while still waiting for the second input;
#   otherwise +@next.forward(output)+ when a next link is set, or the layer
#   node's output itself when it is not.
def forward(x)
  @hold << x
  return if @hold.length < 2

  # Detach the pair and reset the buffer *before* calling the layer node.
  # If the call raises, no stale inputs are left behind to be mis-paired
  # with (or passed alongside) the inputs of the next forward pass.
  inputs = @hold
  @hold = []
  output = @layer_node.(*inputs)
  @next ? @next.forward(output) : output
end