Class: NN::Affine

Inherits:
Object
  • Object
show all
Includes:
Numo
Defined in:
lib/nn.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(nn, index) ⇒ Affine

Returns a new instance of Affine.


279
280
281
282
283
284
# File 'lib/nn.rb', line 279

# Creates the layer as a view onto parameter slot +index+ of network +nn+.
# Weights and biases are not copied here; they are looked up through
# @nn.weights[@index] / @nn.biases[@index] whenever they are needed.
#
# @param nn    owning network object (parameters and weight_decay are read from it)
# @param index position of this layer's parameters in nn.weights / nn.biases
def initialize(nn, index)
  @nn = nn
  @index = index
  # Gradient slots stay nil until a backward pass fills them in.
  @d_weight = @d_bias = nil
end

Instance Attribute Details

#d_bias ⇒ Object (readonly)

Returns the bias gradient stored by the most recent #backward call (nil until #backward has run).


277
278
279
# File 'lib/nn.rb', line 277

def d_bias
  @d_bias
end

#d_weight ⇒ Object (readonly)

Returns the weight gradient stored by the most recent #backward call (nil until #backward has run).


276
277
278
# File 'lib/nn.rb', line 276

# Weight gradient written by the most recent #backward call (including the
# weight-decay term when it applies); nil before the first backward pass.
def d_weight; @d_weight; end

Instance Method Details

#backward(dout) ⇒ Object


291
292
293
294
295
296
297
298
299
300
# File 'lib/nn.rb', line 291

# Backpropagates +dout+ (the upstream gradient w.r.t. this layer's output)
# through the affine transform.
#
# Side effects:
# - @d_weight: the cached forward input reshaped to (*shape, 1), dotted with
#   +dout+ reshaped to (rows, 1, cols). This presumably yields one weight
#   gradient per leading-axis row (confirm Numo's N-d dot semantics, and how
#   the optimizer reduces it). When @nn.weight_decay > 0, it also includes
#   weight_decay * W (the ridge term).
# - @d_bias: +dout+ itself, not reduced over the leading axis.
#
# Reads @x, so #forward must have run first.
#
# @param dout upstream gradient; treated as 2-D via dout.shape[0] / shape[1]
# @return gradient w.r.t. the layer input, dout · Wᵀ
def backward(dout)
  rows, cols = dout.shape[0], dout.shape[1]
  input_col = @x.reshape(*@x.shape, 1)
  dout_row = dout.reshape(rows, 1, cols)
  @d_weight = input_col.dot(dout_row)

  decay = @nn.weight_decay
  @d_weight += decay * @nn.weights[@index] if decay > 0

  @d_bias = dout
  dout.dot(@nn.weights[@index].transpose)
end

#forward(x) ⇒ Object


286
287
288
289
# File 'lib/nn.rb', line 286

# Applies the affine map x · W + b, where W and b are this layer's slots in
# @nn.weights / @nn.biases.
#
# The input is cached in @x because #backward needs it to build the weight
# gradient. Presumably x is (batch, in_features); confirm against callers.
#
# @param x layer input (must respond to #dot)
# @return x · W + b
def forward(x)
  @x = x
  weight = @nn.weights[@index]
  bias = @nn.biases[@index]
  x.dot(weight) + bias
end