Class: Rumale::Optimizer::RMSProp
- Inherits: Object
- Includes: Base::BaseEstimator
- Defined in: lib/rumale/optimizer/rmsprop.rb
Overview
RMSProp is a class that implements the RMSProp optimizer combined with momentum.
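The arithmetic in the source of #call (lib/rumale/optimizer/rmsprop.rb) can be summarized as the update rule below. The symbols are notation chosen here, not names from the library: w is the weight, g the gradient, v the @moment buffer, u the @update buffer, ρ = decay, μ = momentum, η = learning_rate, and ε = 1.0e-8. All operations are element-wise, and v and u start at zero.

```latex
\begin{aligned}
v_t &= \rho\, v_{t-1} + (1 - \rho)\, g_t^{2} \\
u_t &= \mu\, u_{t-1} - \frac{\eta}{\sqrt{v_t} + \epsilon}\, g_t \\
w_t &= w_{t-1} + u_t
\end{aligned}
```

Dividing by the running root-mean-square of the gradient gives each coordinate its own effective learning rate, while the μ u_{t-1} term carries over part of the previous step.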
References
- I. Sutskever, J. Martens, G. Dahl, and G. Hinton, “On the importance of initialization and momentum in deep learning,” Proc. ICML'13, pp. 1139–1147, 2013.
- G. Hinton, N. Srivastava, and K. Swersky, “Lecture 6e rmsprop,” Neural Networks for Machine Learning, 2012.
Instance Attribute Summary
Attributes included from Base::BaseEstimator
Instance Method Summary
- #call(weight, gradient) ⇒ Numo::DFloat
  Calculate the updated weight with RMSProp adaptive learning rate.
- #initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9) ⇒ RMSProp (constructor)
  Create a new optimizer with RMSProp.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
Constructor Details
#initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9) ⇒ RMSProp
Create a new optimizer with RMSProp.
```ruby
# File 'lib/rumale/optimizer/rmsprop.rb', line 27

def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
  # Validate the type and sign of the hyperparameters.
  check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
  check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:momentum] = momentum
  @params[:decay] = decay
  # The state buffers are allocated lazily on the first call to #call.
  @moment = nil
  @update = nil
end
```
Instance Method Details
#call(weight, gradient) ⇒ Numo::DFloat
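Calculate the updated weight with RMSProp adaptive learning rate. The method is stateful: each call updates the internal moment and update buffers and returns weight + update without modifying weight in place. Because the buffers are allocated with length weight.shape[0], weight and gradient are expected to be 1-D vectors of the same length.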
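```ruby
# File 'lib/rumale/optimizer/rmsprop.rb', line 43

def call(weight, gradient)
  # Lazily allocate the state buffers with the weight's length.
  @moment ||= Numo::DFloat.zeros(weight.shape[0])
  @update ||= Numo::DFloat.zeros(weight.shape[0])
  # Exponential moving average of the squared gradient.
  @moment = @params[:decay] * @moment + (1.0 - @params[:decay]) * gradient**2
  # Momentum step scaled by the per-element adaptive learning rate.
  @update = @params[:momentum] * @update - (@params[:learning_rate] / (@moment**0.5 + 1.0e-8)) * gradient
  weight + @update
end
```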
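To make the step easy to experiment with without Numo, here is a dependency-free sketch that mirrors the arithmetic of #call, assuming plain Ruby Arrays in place of Numo::DFloat. The class name PlainRMSProp is hypothetical and not part of Rumale; it is an illustration of the same update, not the library's implementation.

```ruby
# Hypothetical plain-Ruby sketch of the RMSProp-with-momentum step performed by
# Rumale::Optimizer::RMSProp#call. Uses Arrays instead of Numo::DFloat.
class PlainRMSProp
  EPS = 1.0e-8 # same constant as in the Rumale source

  def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
    @learning_rate = learning_rate
    @momentum = momentum
    @decay = decay
    @moment = nil # running average of squared gradients
    @update = nil # previous step, reused as the momentum term
  end

  # Returns the updated weight as a new Array; the argument is not modified.
  def call(weight, gradient)
    @moment ||= Array.new(weight.size, 0.0)
    @update ||= Array.new(weight.size, 0.0)
    @moment = @moment.zip(gradient).map { |m, g| @decay * m + (1.0 - @decay) * g**2 }
    # Inside the block, @update still refers to the previous step.
    @update = Array.new(weight.size) do |i|
      @momentum * @update[i] - (@learning_rate / (Math.sqrt(@moment[i]) + EPS)) * gradient[i]
    end
    weight.zip(@update).map { |w, u| w + u }
  end
end

# One step on f(w) = w_1^2 + w_2^2, whose gradient is 2w.
opt = PlainRMSProp.new
w = opt.call([1.0, -2.0], [2.0, -4.0])
p w
```

On the first call the moment buffer equals (1 - decay) g², so each coordinate moves by roughly learning_rate / sqrt(1 - decay), about 0.0316 with the defaults, whatever the gradient's scale. A later call with a zero gradient still moves the weight by momentum times the previous step.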
#marshal_dump ⇒ Hash
Dump marshal data: the hyperparameters together with the moment and update buffers.
```ruby
# File 'lib/rumale/optimizer/rmsprop.rb', line 53

def marshal_dump
  { params: @params, moment: @moment, update: @update }
end
```
#marshal_load(obj) ⇒ nil
Load marshal data, restoring the hyperparameters and the moment and update buffers.
```ruby
# File 'lib/rumale/optimizer/rmsprop.rb', line 61

def marshal_load(obj)
  @params = obj[:params]
  @moment = obj[:moment]
  @update = obj[:update]
  nil
end
```
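Because the optimizer's state lives in @moment and @update, these two hooks are what let a partly trained optimizer be saved and resumed with Ruby's Marshal module. The sketch below uses a hypothetical stand-in class (StatefulOptimizer and its warm_up helper are illustrative, not Rumale API) that defines the same pair of hooks and round-trips its state through Marshal.dump and Marshal.load.

```ruby
# Hypothetical stand-in mirroring how RMSProp persists its state through the
# marshal_dump / marshal_load hooks. Not Rumale code.
class StatefulOptimizer
  attr_reader :params, :moment, :update

  def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
    @params = { learning_rate: learning_rate, momentum: momentum, decay: decay }
    @moment = nil
    @update = nil
  end

  # Hypothetical helper: fill the buffers as if some steps had been taken.
  def warm_up(size)
    @moment = Array.new(size, 0.5)
    @update = Array.new(size, -0.1)
    self
  end

  # Marshal.dump calls this and serializes the returned Hash.
  def marshal_dump
    { params: @params, moment: @moment, update: @update }
  end

  # Marshal.load allocates an instance without running initialize,
  # then calls this with the deserialized Hash.
  def marshal_load(obj)
    @params = obj[:params]
    @moment = obj[:moment]
    @update = obj[:update]
    nil
  end
end

original = StatefulOptimizer.new(learning_rate: 0.05).warm_up(3)
restored = Marshal.load(Marshal.dump(original))
p restored.moment
```

Since Marshal.load skips initialize, marshal_load has to restore every instance variable the optimizer needs; returning nil matches the documented ⇒ nil signature.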