Class: Torch::Optim::Adamax
- Defined in:
- lib/torch/optim/adamax.rb
Instance Attribute Summary
Attributes inherited from Optimizer
Instance Method Summary
- #initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0) ⇒ Adamax
  constructor
  A new instance of Adamax.
- #step(closure = nil) ⇒ Object
Methods inherited from Optimizer
#add_param_group, #load_state_dict, #state_dict, #zero_grad
Constructor Details
#initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0) ⇒ Adamax
Returns a new instance of Adamax.
# File 'lib/torch/optim/adamax.rb', line 5

def initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)
  raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
  raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
  raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
  raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
  raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

  defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay}
  super(params, defaults)
end
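For orientation, here is a minimal construction sketch. The model, layer sizes, and variable names are illustrative assumptions and not part of this class's documentation; only the Torch::Optim::Adamax.new call and its keywords come from the signature above.

require "torch"

# Hypothetical model, used only to produce parameters for the optimizer.
model = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(10, 32),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(32, 1)
)

# Keyword defaults mirror the constructor signature; any of them can be overridden.
optimizer = Torch::Optim::Adamax.new(model.parameters, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)

Passing a value outside the documented ranges (for example a negative lr) raises ArgumentError, as the source above shows.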
Instance Method Details
#step(closure = nil) ⇒ Object
# File 'lib/torch/optim/adamax.rb', line 16

def step(closure = nil)
  loss = nil
  if closure
    loss = closure.call
  end

  @param_groups.each do |group|
    group[:params].each do |p|
      next unless p.grad

      grad = p.grad.data
      if grad.sparse?
        raise Error, "Adamax does not support sparse gradients, please consider SparseAdam instead"
      end
      state = @state[p]

      # State initialization
      if state.size == 0
        state[:step] = 0
        state[:exp_avg] = Torch.zeros_like(p.data)
        state[:exp_inf] = Torch.zeros_like(p.data)
      end

      exp_avg, exp_inf = state[:exp_avg], state[:exp_inf]
      beta1, beta2 = group[:betas]
      eps = group[:eps]

      state[:step] += 1

      if group[:weight_decay] != 0
        grad = grad.add(p.data, alpha: group[:weight_decay])
      end

      # Update biased first moment estimate.
      exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
      # Update the exponentially weighted infinity norm.
      norm_buf = Torch.cat([
        exp_inf.mul!(beta2).unsqueeze(0),
        grad.abs.add!(eps).unsqueeze!(0)
      ], 0)
      Torch.max(norm_buf, 0, keepdim: false, out: [exp_inf, exp_inf.new.long])

      bias_correction = 1 - beta1 ** state[:step]
      clr = group[:lr] / bias_correction

      p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
    end
  end

  loss
end
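As a usage sketch, the loop below shows how #step is typically driven during training. The tensors, loss function, and iteration count are illustrative assumptions rather than part of this class's API; #zero_grad is inherited from Optimizer (see above).

require "torch"

# Toy data, purely for illustration.
x = Torch.randn(64, 10)
y = Torch.randn(64, 1)

model = Torch::NN::Linear.new(10, 1)
criterion = Torch::NN::MSELoss.new
optimizer = Torch::Optim::Adamax.new(model.parameters, lr: 2e-3)

10.times do
  optimizer.zero_grad                     # clear gradients from the previous iteration
  loss = criterion.call(model.call(x), y) # forward pass
  loss.backward                           # populate p.grad for each parameter
  optimizer.step                          # apply the Adamax update shown above
end

If a closure is passed, #step calls it before updating the parameters and returns its result as loss; otherwise it returns nil.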