Class: Torch::Optim::LRScheduler::MultiplicativeLR
- Inherits: LRScheduler
  - Object
  - LRScheduler
  - Torch::Optim::LRScheduler::MultiplicativeLR
- Defined in: lib/torch/optim/lr_scheduler/multiplicative_lr.rb
Instance Method Summary
- #get_lr ⇒ Object
- #initialize(optimizer, lr_lambda, last_epoch: -1) ⇒ MultiplicativeLR (constructor)
  A new instance of MultiplicativeLR.
Methods inherited from LRScheduler
Constructor Details
#initialize(optimizer, lr_lambda, last_epoch: -1) ⇒ MultiplicativeLR
Returns a new instance of MultiplicativeLR.
# File 'lib/torch/optim/lr_scheduler/multiplicative_lr.rb', line 5
def initialize(optimizer, lr_lambda, last_epoch: -1)
  @optimizer = optimizer
  if !lr_lambda.is_a?(Array)
    @lr_lambdas = [lr_lambda] * optimizer.param_groups.length
  else
    if lr_lambda.length != optimizer.param_groups.length
      raise ArgumentError, "Expected #{optimizer.param_groups.length}, but got #{lr_lambda.length}"
    end
    @lr_lambdas = lr_lambda
  end
  @last_epoch = last_epoch
  super(optimizer, last_epoch)
end
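The constructor accepts either a single callable or an Array holding one callable per parameter group. The following is a minimal plain-Ruby sketch of that argument handling, not the library's own code: `normalize_lambdas` is a hypothetical helper introduced for illustration, and `group_count` stands in for `optimizer.param_groups.length`.

```ruby
# Hypothetical helper (not part of torch.rb) mirroring how the
# constructor turns lr_lambda into one callable per parameter group.
def normalize_lambdas(lr_lambda, group_count)
  # A single callable is reused for every parameter group.
  return [lr_lambda] * group_count unless lr_lambda.is_a?(Array)

  # An Array must supply exactly one callable per group.
  if lr_lambda.length != group_count
    raise ArgumentError, "Expected #{group_count}, but got #{lr_lambda.length}"
  end
  lr_lambda
end

decay = ->(_epoch) { 0.95 }

shared = normalize_lambdas(decay, 3)
puts shared.length # one entry per group, all the same lambda

begin
  normalize_lambdas([decay], 2)
rescue ArgumentError => e
  puts e.message
end
```

Because the single-callable case repeats the same object, all groups share one lambda; pass an Array when groups should decay at different rates.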
Instance Method Details
#get_lr ⇒ Object
# File 'lib/torch/optim/lr_scheduler/multiplicative_lr.rb', line 20
def get_lr
  if @last_epoch > 0
    @lr_lambdas.zip(@optimizer.param_groups).map do |lmbda, group|
      group[:lr] * lmbda.call(@last_epoch)
    end
  else
    @base_lrs
  end
end
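Note that `get_lr` multiplies each group's current learning rate, not its base rate, so the factors compound from one epoch to the next. The sketch below simulates that rule in plain Ruby without the torch gem. `next_lrs` and the sample rates are hypothetical and only illustrate the arithmetic shown in the method body.

```ruby
# Hypothetical stand-in for get_lr: returns the base rates while
# last_epoch is 0 or less, otherwise scales each current rate by the
# factor its lambda returns for that epoch.
def next_lrs(current_lrs, base_lrs, lr_lambdas, last_epoch)
  return base_lrs unless last_epoch > 0

  lr_lambdas.zip(current_lrs).map { |lmbda, lr| lr * lmbda.call(last_epoch) }
end

base_lrs = [0.1, 0.01]                               # assumed sample rates
lambdas  = [->(_epoch) { 0.5 }, ->(_epoch) { 0.9 }]  # one factor per group

lrs = base_lrs
(0..3).each do |epoch|
  lrs = next_lrs(lrs, base_lrs, lambdas, epoch)
  puts "epoch #{epoch}: #{lrs.inspect}"
end
```

With a constant factor f, the rate after n positive epochs is the base rate times f raised to the power n, which is the main difference from a schedule that always scales the base rate.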