Class: Rumale::Optimizer::SGD
- Inherits: Object
- Includes: Base::BaseEstimator
- Defined in: lib/rumale/optimizer/sgd.rb
Overview
SGD is a class that implements a stochastic gradient descent (SGD) optimizer with momentum and learning rate decay.
Instance Attribute Summary
Attributes included from Base::BaseEstimator
Instance Method Summary
- #call(weight, gradient) ⇒ Numo::DFloat
  Calculate the updated weight with SGD.
- #initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0) ⇒ SGD (constructor)
  Create a new optimizer with SGD.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
Constructor Details
#initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0) ⇒ SGD
Create a new optimizer with SGD.
# File 'lib/rumale/optimizer/sgd.rb', line 23

def initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0)
  check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
  check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
  @params = {}
  @params[:learning_rate] = learning_rate
  @params[:momentum] = momentum
  @params[:decay] = decay
  @iter = 0
  @update = nil
end
Instance Method Details
#call(weight, gradient) ⇒ Numo::DFloat
Calculate the updated weight with SGD.
# File 'lib/rumale/optimizer/sgd.rb', line 39

def call(weight, gradient)
  @update ||= Numo::DFloat.zeros(weight.shape[0])
  current_learning_rate = @params[:learning_rate] / (1.0 + @params[:decay] * @iter)
  @iter += 1
  @update = @params[:momentum] * @update - current_learning_rate * gradient
  weight + @update
end
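The update rule above can be illustrated with a minimal plain-Ruby sketch. `TinySGD` below is a hypothetical stand-in (not part of Rumale) that uses plain Arrays instead of Numo::DFloat so it runs without the numo-narray dependency; the arithmetic mirrors #call:

```ruby
# Hypothetical sketch of the SGD-with-momentum update performed in #call.
# Plain Arrays stand in for Numo::DFloat vectors.
class TinySGD
  def initialize(learning_rate: 0.01, momentum: 0.0, decay: 0.0)
    @learning_rate = learning_rate
    @momentum = momentum
    @decay = decay
    @iter = 0
    @update = nil
  end

  def call(weight, gradient)
    @update ||= Array.new(weight.size, 0.0)
    # The effective learning rate shrinks as 1 / (1 + decay * iteration).
    current_learning_rate = @learning_rate / (1.0 + @decay * @iter)
    @iter += 1
    # New velocity: momentum-scaled previous velocity minus the scaled gradient.
    @update = @update.each_with_index.map do |u, i|
      @momentum * u - current_learning_rate * gradient[i]
    end
    weight.each_with_index.map { |v, i| v + @update[i] }
  end
end

opt = TinySGD.new(learning_rate: 0.1, momentum: 0.9)
w = opt.call([1.0, -2.0], [0.5, -0.5])
# First step: velocity is zero, so w moves by -learning_rate * gradient,
# giving [0.95, -1.95]; later steps also carry 0.9 of the previous velocity.
```

On the first call the velocity term vanishes, so the step is plain gradient descent; momentum only influences subsequent calls, and a nonzero `decay` gradually shrinks the step size.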
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/rumale/optimizer/sgd.rb', line 49

def marshal_dump
  { params: @params, iter: @iter, update: @update }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/rumale/optimizer/sgd.rb', line 57

def marshal_load(obj)
  @params = obj[:params]
  @iter = obj[:iter]
  @update = obj[:update]
  nil
end
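Because #marshal_dump returns a plain Hash and #marshal_load restores state from one, Ruby's built-in Marshal can round-trip the optimizer's state. A minimal sketch, using a hypothetical `TinyOpt` class (not the real Rumale class) that defines the same two hooks:

```ruby
# Sketch of the Marshal round-trip enabled by #marshal_dump / #marshal_load.
# TinyOpt is a hypothetical stand-in holding the same state as the source.
class TinyOpt
  attr_reader :params, :iter, :update

  def initialize(params = {}, iter = 0, update = nil)
    @params = params
    @iter = iter
    @update = update
  end

  # Marshal.dump calls this hook to obtain serializable state.
  def marshal_dump
    { params: @params, iter: @iter, update: @update }
  end

  # Marshal.load calls this hook with the dumped Hash to restore state.
  def marshal_load(obj)
    @params = obj[:params]
    @iter = obj[:iter]
    @update = obj[:update]
    nil
  end
end

opt = TinyOpt.new({ learning_rate: 0.01 }, 5, [0.1, 0.2])
restored = Marshal.load(Marshal.dump(opt))
# restored carries the same params, iteration count, and velocity vector.
```

Persisting `@iter` and `@update` alongside `@params` matters: a reloaded optimizer continues the learning-rate decay schedule and momentum history where it left off rather than restarting from step zero.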