Class: Torch::Optim::LRScheduler::LambdaLR

Inherits:
Torch::Optim::LRScheduler
Defined in:
lib/torch/optim/lr_scheduler/lambda_lr.rb

Instance Method Summary

Constructor Details

#initialize(optimizer, lr_lambda, last_epoch: -1) ⇒ LambdaLR

Returns a new instance of LambdaLR.



5
6
7
8
9
10
11
12
13
14
15
16
17
18
# File 'lib/torch/optim/lr_scheduler/lambda_lr.rb', line 5

# Creates a scheduler whose learning rates are produced by per-group
# callables of the epoch (see #get_lr, which multiplies each group's base
# learning rate by the callable's result).
#
# @param optimizer [Object] optimizer whose +param_groups+ are scheduled
# @param lr_lambda [#call, Array<#call>] either a single callable shared by
#   every param group, or an Array holding exactly one callable per group
# @param last_epoch [Integer] forwarded unchanged to the base scheduler
# @raise [ArgumentError] if an Array is given whose length differs from the
#   number of param groups
def initialize(optimizer, lr_lambda, last_epoch: -1)
  @optimizer = optimizer
  group_count = optimizer.param_groups.length

  if lr_lambda.is_a?(Array)
    if lr_lambda.length != group_count
      raise ArgumentError, "Expected #{group_count} lr_lambdas, but got #{lr_lambda.length}"
    end
    # Copy so that later mutation of the caller's array cannot silently
    # change this scheduler's behavior.
    @lr_lambdas = lr_lambda.dup
  else
    # One callable, reused for every param group.
    @lr_lambdas = Array.new(group_count, lr_lambda)
  end

  # NOTE(review): the base class presumably sets this too; kept as in the
  # original so the ivar exists before super runs.
  @last_epoch = last_epoch
  super(optimizer, last_epoch)
end

Instance Method Details

#get_lr ⇒ Object



20
21
22
23
24
# File 'lib/torch/optim/lr_scheduler/lambda_lr.rb', line 20

# Learning rates for the current epoch, one entry per param group.
#
# Each group's base learning rate is scaled by the value its callable
# returns for @last_epoch.
#
# @return [Array] scaled learning rates, in param-group order
def get_lr
  @lr_lambdas.zip(@base_lrs).map do |(schedule, initial_lr)|
    factor = schedule.call(@last_epoch)
    initial_lr * factor
  end
end