Class: Rumale::Optimizer::YellowFin
- Inherits: Object
- Includes: Base::BaseEstimator
- Defined in: lib/rumale/optimizer/yellow_fin.rb
Overview
YellowFin is a class that implements the YellowFin optimizer.

Reference
- J. Zhang and I. Mitliagkas, “YellowFin and the Art of Momentum Tuning,” CoRR abs/1706.03471, 2017.
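A usage sketch follows. It is hedged: the toy data, and the assumption that a Rumale linear model such as Rumale::LinearModel::LogisticRegression accepts an optimizer: keyword, are illustrative; check the estimator documentation for your Rumale version.

  require 'rumale'
  require 'numo/narray'

  # Toy data: 100 two-dimensional samples with 0/1 labels.
  x = Numo::DFloat.new(100, 2).rand
  y = Numo::Int32.cast(x.sum(axis: 1) > 1.0)

  # Hand the optimizer to an estimator that takes an optimizer: keyword
  # (assumed here for illustration).
  optimizer = Rumale::Optimizer::YellowFin.new(learning_rate: 0.01, momentum: 0.9)
  estimator = Rumale::LinearModel::LogisticRegression.new(optimizer: optimizer)
  estimator.fit(x, y)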
Instance Attribute Summary
Attributes included from Base::BaseEstimator
Instance Method Summary
- #call(weight, gradient) ⇒ Numo::DFloat
  Calculate the updated weight with adaptive momentum coefficient and learning rate.
- #initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20) ⇒ YellowFin (constructor)
  Create a new optimizer with YellowFin.
Constructor Details
#initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20) ⇒ YellowFin
Create a new optimizer with YellowFin.
  # File 'lib/rumale/optimizer/yellow_fin.rb', line 27

  def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.999, window_width: 20)
    check_params_float(learning_rate: learning_rate, momentum: momentum, decay: decay)
    check_params_integer(window_width: window_width)
    check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay,
                          window_width: window_width)
    @params = {}
    @params[:learning_rate] = learning_rate
    @params[:momentum] = momentum
    @params[:decay] = decay
    @params[:window_width] = window_width
    @smth_learning_rate = learning_rate
    @smth_momentum = momentum
    @grad_norms = nil
    @grad_norm_min = 0.0
    @grad_norm_max = 0.0
    @grad_mean_sqr = 0.0
    @grad_mean = 0.0
    @grad_var = 0.0
    @grad_norm_mean = 0.0
    @curve_mean = 0.0
    @distance_mean = 0.0
    @update = nil
  end
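A short construction sketch with the defaults written out. The per-parameter comments are inferred from the constructor body and the YellowFin paper, not stated on this page:

  require 'rumale'

  optimizer = Rumale::Optimizer::YellowFin.new(
    learning_rate: 0.01,  # initial learning rate, before adaptive tuning
    momentum: 0.9,        # initial momentum coefficient, before adaptive tuning
    decay: 0.999,         # decay rate of the exponential moving averages
    window_width: 20      # sliding-window size for the curvature-range estimate
  )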
Instance Method Details
#call(weight, gradient) ⇒ Numo::DFloat
Calculate the updated weight with adaptive momentum coefficient and learning rate.
  # File 'lib/rumale/optimizer/yellow_fin.rb', line 55

  def call(weight, gradient)
    @update ||= Numo::DFloat.zeros(weight.shape[0])
    curvature_range(gradient)
    gradient_variance(gradient)
    distance_to_optimum(gradient)
    @smth_momentum = @params[:decay] * @smth_momentum + (1 - @params[:decay]) * current_momentum
    @smth_learning_rate = @params[:decay] * @smth_learning_rate + (1 - @params[:decay]) * current_learning_rate
    @update = @smth_momentum * @update - @smth_learning_rate * gradient
    weight + @update
  end
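A minimal sketch of driving #call by hand: minimizing the quadratic f(w) = 0.5 * ||w - t||^2, whose gradient is w - t. The target vector and iteration count are illustrative, not prescribed by the API:

  require 'rumale'
  require 'numo/narray'

  target = Numo::DFloat[1.0, -2.0, 3.0]
  weight = Numo::DFloat.zeros(3)
  optimizer = Rumale::Optimizer::YellowFin.new

  1000.times do
    gradient = weight - target  # gradient of 0.5 * ||weight - target||^2
    weight = optimizer.call(weight, gradient)
  end
  # weight should now be close to target.

Note that #call keeps per-optimizer state (the momentum buffer @update and the moving statistics), so one optimizer instance should be dedicated to one weight vector.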