Class: Rumale::Optimizer::Adam
- Inherits:
- Object (hierarchy: Object > Rumale::Optimizer::Adam)
- Includes:
- Base::BaseEstimator
- Defined in:
- lib/rumale/optimizer/adam.rb
Overview
Adam is a class that implements the Adam optimizer.
Reference
- D. P. Kingma and J. Ba, “Adam: A Method for Stochastic Optimization,” Proc. ICLR’15, 2015.
Instance Attribute Summary
Attributes included from Base::BaseEstimator
Instance Method Summary
- #call(weight, gradient) ⇒ Numo::DFloat
  Calculate the updated weight with the Adam adaptive learning rate.
- #initialize(learning_rate: 0.001, decay1: 0.9, decay2: 0.999) ⇒ Adam (constructor)
  Create a new optimizer with Adam.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
Constructor Details
#initialize(learning_rate: 0.001, decay1: 0.9, decay2: 0.999) ⇒ Adam
Create a new optimizer with Adam.
# File 'lib/rumale/optimizer/adam.rb', line 26
def initialize(learning_rate: 0.001, decay1: 0.9, decay2: 0.999)
  check_params_float(learning_rate: learning_rate, decay1: decay1, decay2: decay2)
  check_params_positive(learning_rate: learning_rate, decay1: decay1, decay2: decay2)
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:decay1] = decay1
  @params[:decay2] = decay2
  @fst_moment = nil
  @sec_moment = nil
  @iter = 0
end
Instance Method Details
#call(weight, gradient) ⇒ Numo::DFloat
Calculate the updated weight with the Adam adaptive learning rate.
# File 'lib/rumale/optimizer/adam.rb', line 43
def call(weight, gradient)
  @fst_moment ||= Numo::DFloat.zeros(weight.shape[0])
  @sec_moment ||= Numo::DFloat.zeros(weight.shape[0])

  @iter += 1

  @fst_moment = @params[:decay1] * @fst_moment + (1.0 - @params[:decay1]) * gradient
  @sec_moment = @params[:decay2] * @sec_moment + (1.0 - @params[:decay2]) * gradient**2
  nm_fst_moment = @fst_moment / (1.0 - @params[:decay1]**@iter)
  nm_sec_moment = @sec_moment / (1.0 - @params[:decay2]**@iter)

  weight - @params[:learning_rate] * nm_fst_moment / (nm_sec_moment**0.5 + 1e-8)
end
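The update performed by #call can be traced by hand with a small sketch. The version below is an illustration only, not Rumale code: it uses plain Ruby Arrays instead of Numo::DFloat so that it runs without any gems, and the helper name `adam_step`, the `state` Hash, and the `eps` keyword are hypothetical names introduced here. It follows the same steps as the listing above: moving averages of the gradient and squared gradient, bias correction by the step count, then the scaled update.

```ruby
# Hypothetical helper (not part of Rumale): one Adam step on plain Ruby Arrays.
# `state` is a Hash that keeps :m (first moment), :v (second moment), and :t (step count)
# between calls, playing the role of @fst_moment, @sec_moment, and @iter.
def adam_step(weight, gradient, state, learning_rate: 0.001, decay1: 0.9, decay2: 0.999, eps: 1e-8)
  state[:m] ||= Array.new(weight.size, 0.0)
  state[:v] ||= Array.new(weight.size, 0.0)
  state[:t] = state.fetch(:t, 0) + 1

  # Exponential moving averages of the gradient and of the squared gradient.
  state[:m] = state[:m].zip(gradient).map { |m, g| decay1 * m + (1.0 - decay1) * g }
  state[:v] = state[:v].zip(gradient).map { |v, g| decay2 * v + (1.0 - decay2) * g**2 }

  # Bias-correction denominators; they compensate for the zero initialization.
  m_scale = 1.0 - decay1**state[:t]
  v_scale = 1.0 - decay2**state[:t]

  weight.each_with_index.map do |w, i|
    m_hat = state[:m][i] / m_scale
    v_hat = state[:v][i] / v_scale
    w - learning_rate * m_hat / (Math.sqrt(v_hat) + eps)
  end
end

state = {}
weight = [1.0, -2.0]
gradient = [0.5, -0.25]
updated = adam_step(weight, gradient, state)
puts updated.inspect
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so each weight moves by roughly `learning_rate` in the direction opposite to its gradient, regardless of the gradient's magnitude.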
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/rumale/optimizer/adam.rb', line 59
def marshal_dump
  { params: @params,
    fst_moment: @fst_moment,
    sec_moment: @sec_moment,
    iter: @iter }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/rumale/optimizer/adam.rb', line 68
def marshal_load(obj)
  @params = obj[:params]
  @fst_moment = obj[:fst_moment]
  @sec_moment = obj[:sec_moment]
  @iter = obj[:iter]
  nil
end
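Ruby's Marshal module calls #marshal_dump when an object is serialized and #marshal_load on a freshly allocated object when it is restored, which is how the optimizer's moments and step count can survive a save and reload. The sketch below shows that round trip with a hypothetical stand-in class, `TinyOptimizerState`, which is not part of Rumale and keeps only a parameter Hash and a counter.

```ruby
# Hypothetical stand-in (not Rumale's Adam): shows the marshal_dump / marshal_load
# round trip using the same Hash-based pattern as the listings above.
class TinyOptimizerState
  attr_reader :params, :iter

  def initialize(learning_rate: 0.001)
    @params = { learning_rate: learning_rate }
    @iter = 0
  end

  # Stands in for an optimizer update; only advances the step counter.
  def step
    @iter += 1
  end

  # Called by Marshal.dump; the returned Hash is what gets serialized.
  def marshal_dump
    { params: @params, iter: @iter }
  end

  # Called by Marshal.load on a newly allocated instance with the dumped Hash.
  def marshal_load(obj)
    @params = obj[:params]
    @iter = obj[:iter]
    nil
  end
end

original = TinyOptimizerState.new(learning_rate: 0.01)
3.times { original.step }

restored = Marshal.load(Marshal.dump(original))
puts restored.iter
puts restored.params[:learning_rate]
```

Because the step count is part of the dumped state, a restored optimizer would continue its bias correction from where it left off rather than restarting from the first iteration.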