Class: Torch::Optim::Adamax

Inherits: Optimizer

Defined in: lib/torch/optim/adamax.rb

Instance Attribute Summary

Attributes inherited from Optimizer

#param_groups

Instance Method Summary

  • #step(closure = nil) ⇒ Object

Methods inherited from Optimizer

#add_param_group, #load_state_dict, #state_dict, #zero_grad

Constructor Details

#initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0) ⇒ Adamax

Returns a new instance of Adamax.

Raises:

  • (ArgumentError)
# File 'lib/torch/optim/adamax.rb', line 5

def initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)
  raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
  raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
  raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
  raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
  raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

  defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay}
  super(params, defaults)
end
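
A minimal construction sketch (not taken from the library docs): assuming a model built with torch.rb's Torch::NN modules, the optimizer takes the parameters returned by model.parameters plus the keyword arguments validated above. The network below is hypothetical and only for illustration.

require "torch"

# Hypothetical two-layer regression network (illustration only).
model = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(10, 32),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(32, 1)
)

# Defaults match the constructor signature: lr: 2e-3, betas: [0.9, 0.999],
# eps: 1e-8, weight_decay: 0. An out-of-range value raises ArgumentError.
optimizer = Torch::Optim::Adamax.new(model.parameters, lr: 2e-3, weight_decay: 1e-4)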

Instance Method Details

#step(closure = nil) ⇒ Object

# File 'lib/torch/optim/adamax.rb', line 16

def step(closure = nil)
  loss = nil
  if closure
    loss = closure.call
  end

  @param_groups.each do |group|
    group[:params].each do |p|
      next unless p.grad
      grad = p.grad.data
      if grad.sparse?
        raise Error, "Adamax does not support sparse gradients, please consider SparseAdam instead"
      end
      state = @state[p]

      # State initialization
      if state.size == 0
        state[:step] = 0
        state[:exp_avg] = Torch.zeros_like(p.data)
        state[:exp_inf] = Torch.zeros_like(p.data)
      end

      exp_avg, exp_inf = state[:exp_avg], state[:exp_inf]
      beta1, beta2 = group[:betas]
      eps = group[:eps]

      state[:step] += 1

      if group[:weight_decay] != 0
        grad = grad.add(p.data, alpha: group[:weight_decay])
      end

      # Update biased first moment estimate.
      exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
      # Update the exponentially weighted infinity norm.
      norm_buf = Torch.cat([
          exp_inf.mul!(beta2).unsqueeze(0),
          grad.abs.add!(eps).unsqueeze!(0)
      ], 0)
      Torch.max(norm_buf, 0, keepdim: false, out: [exp_inf, exp_inf.new.long])

      bias_correction = 1 - beta1 ** state[:step]
      clr = group[:lr] / bias_correction

      p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
    end
  end

  loss
end
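
A hedged training-loop sketch showing how #step is usually invoked, continuing from the construction example above (model, optimizer) with made-up random data; zero_grad and backward populate p.grad before the Adamax update runs.

x = Torch.randn(8, 10)
y = Torch.randn(8, 1)
loss_fn = Torch::NN::MSELoss.new

5.times do
  optimizer.zero_grad                     # clear gradients from the previous iteration
  loss = loss_fn.call(model.call(x), y)   # forward pass and loss
  loss.backward                           # populate p.grad for every parameter
  optimizer.step                          # apply the Adamax update implemented above
end

Passing a closure is optional: if one is given, #step calls it and returns its value as loss; gradients must still have been computed (via backward) before the update is applied.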