Class: NN::Dropout

Inherits: Object
Includes:
Numo
Defined in:
lib/nn.rb

Instance Method Summary

Constructor Details

#initialize(nn) ⇒ Dropout

Returns a new instance of Dropout.



# File 'lib/nn.rb', lines 377-380

# Builds a dropout layer bound to the network that owns it.
#
# @param nn [Object] the owning network; #forward and #backward read its
#   +training+ flag and +dropout_ratio+ (presumably the NN model instance —
#   verify against the caller).
def initialize(nn)
  @mask = nil # replaced by a fresh random mask on every training-mode #forward
  @nn = nn
end

Instance Method Details

#backward(dout) ⇒ Object



# File 'lib/nn.rb', lines 392-395

# Backward pass of the dropout layer: propagates the upstream gradient.
#
# Training mode (+@nn.training+ truthy): elements dropped by the last #forward
# (marked true in +@mask+) contributed nothing to the output, so their gradient
# is zeroed. This is done on a copy, so the caller's +dout+ is left untouched.
#
# Evaluation mode: #forward computed +x * (1 - dropout_ratio)+, so the gradient
# is scaled by the same factor instead of being passed through unchanged.
#
# @param dout [Numo::NArray] upstream gradient (assumed same shape as the
#   forward input — TODO confirm against callers)
# @return [Numo::NArray] a new gradient array; +dout+ is never modified
def backward(dout)
  return dout * (1 - @nn.dropout_ratio) unless @nn.training

  dx = dout.dup
  dx[@mask] = 0
  dx
end

#forward(x) ⇒ Object



# File 'lib/nn.rb', lines 382-390

# Forward pass of the dropout layer.
#
# Training mode (+@nn.training+ truthy): draws a fresh random mask in which
# each element is dropped when its uniform sample is below
# +@nn.dropout_ratio+, and stores it in +@mask+ for #backward. Returns a copy
# of +x+ with the dropped elements set to zero. Working on a copy keeps the
# caller's array intact (a preceding layer may still hold a reference to it)
# and matches the evaluation path, which also returns a new array.
#
# Evaluation mode: nothing is dropped; activations are scaled by
# (1 - dropout_ratio) so their expected magnitude matches training.
#
# @param x [Numo::NArray] input activations (presumably any shape — verify)
# @return [Numo::NArray] a new array; +x+ is never modified
def forward(x)
  return x * (1 - @nn.dropout_ratio) unless @nn.training

  # rand overwrites the ones with uniform samples; true marks a dropped unit.
  @mask = SFloat.ones(*x.shape).rand < @nn.dropout_ratio
  out = x.dup
  out[@mask] = 0
  out
end