Class: Chainer::Optimizers::AdamRule

Inherits:
UpdateRule
Defined in:
lib/chainer/optimizers/adam.rb

Instance Attribute Summary

Attributes inherited from UpdateRule

#state

Instance Method Summary

#initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil) ⇒ AdamRule constructor
#init_state(param) ⇒ Object
#lr ⇒ Object
#update_core(param) ⇒ Object

Methods inherited from UpdateRule

#serialize, #update, #update_core_cpu, #update_core_gpu

Constructor Details

#initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil) ⇒ AdamRule

Returns a new instance of AdamRule.



# File 'lib/chainer/optimizers/adam.rb', line 4

def initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil)
  # Default Adam hyperparameters, used when no parent is supplied.
  hyperparam = Hyperparameter.new
  hyperparam.instance_variable_set('@alpha', 0.001)
  hyperparam.instance_variable_set('@beta1', 0.9)
  hyperparam.instance_variable_set('@beta2', 0.999)
  hyperparam.instance_variable_set('@eps', 1e-8)

  super(parent_hyperparam: parent_hyperparam || hyperparam)

  # Explicit keyword arguments shadow the (parent) defaults per field.
  @hyperparam.instance_variable_set('@alpha', alpha) if alpha
  @hyperparam.instance_variable_set('@beta1', beta1) if beta1
  @hyperparam.instance_variable_set('@beta2', beta2) if beta2
  @hyperparam.instance_variable_set('@eps', eps) if eps
end
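
In normal use this rule is created by the Adam optimizer, one instance per parameter, with the optimizer's shared Hyperparameter passed as parent_hyperparam; any keyword given here shadows that field for this rule only. A minimal sketch, assuming red-chainer is installed and that Chainer::Optimizers::Adam accepts the same keywords (instance_variable_get is used purely for inspection, since UpdateRule is not documented to expose a public hyperparam reader):

require 'chainer'

# Standalone construction: alpha is overridden, the rest keep the defaults.
rule = Chainer::Optimizers::AdamRule.new(alpha: 0.01)
hp = rule.instance_variable_get(:@hyperparam)
hp.alpha  # => 0.01
hp.beta1  # => 0.9, resolved through the parent hyperparameter

# Typical construction: the optimizer creates one rule per parameter.
optimizer = Chainer::Optimizers::Adam.new(alpha: 0.01)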

Instance Method Details

#init_state(param) ⇒ Object



# File 'lib/chainer/optimizers/adam.rb', line 19

def init_state(param)
  # First (:m) and second (:v) moment accumulators, zero-initialized
  # to match the parameter's shape, dtype, and device.
  @state[:m] = param.data.new_zeros
  @state[:v] = param.data.new_zeros
end
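
new_zeros allocates an array with the same shape, dtype, and backend as the parameter's data, so both moment accumulators start at zero alongside the parameter. A rough CPU equivalent, assuming param.data is a Numo::NArray (the variable names are illustrative):

require 'numo/narray'

data = Numo::SFloat.new(3, 2).rand   # stand-in for param.data
m = data.class.zeros(*data.shape)    # what new_zeros yields: zeros of matching shape/dtype
v = data.class.zeros(*data.shape)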

#lr ⇒ Object



# File 'lib/chainer/optimizers/adam.rb', line 36

def lr
  # Bias-corrected step size: alpha * sqrt(1 - beta2^t) / (1 - beta1^t),
  # where @t is the update count maintained by the inherited #update.
  fix1 = 1.0 - @hyperparam.beta1 ** @t
  fix2 = 1.0 - @hyperparam.beta2 ** @t
  @hyperparam.alpha * Math.sqrt(fix2) / fix1
end
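
This folds Adam's bias correction into a single scalar step size. In the notation of Kingma and Ba:

\alpha_t = \alpha \cdot \frac{\sqrt{1 - \beta_2^{t}}}{1 - \beta_1^{t}}

With the defaults above at t = 1 this gives 0.001 * sqrt(1 - 0.999) / (1 - 0.9) ≈ 3.16e-4, and it approaches alpha as t grows.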

#update_core(param) ⇒ Object



# File 'lib/chainer/optimizers/adam.rb', line 24

def update_core(param)
  grad = param.grad
  return if grad.nil?

  hp = @hyperparam

  # Exponential moving averages of the gradient and its square.
  @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
  @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
  # Pick the array backend that owns the gradient, then take the
  # bias-corrected step.
  xm = Chainer.get_array_module(grad)
  param.data -= lr * @state[:m] / (xm::NMath.sqrt(@state[:v]) + hp.eps)
end
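
Each accumulator line is an exponential moving average in incremental form: m += (1 - beta1) * (grad - m) is algebraically the same as m = beta1 * m + (1 - beta1) * grad. Chainer.get_array_module selects the backend holding the gradient (Numo on CPU, Cumo on GPU) so the square root runs on the right device. A self-contained sketch of one step on plain Numo arrays (all names here are illustrative, not part of the API):

require 'numo/narray'

alpha, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8
data = Numo::DFloat[1.0, -2.0, 3.0]   # parameter values
grad = Numo::DFloat[0.1, -0.2, 0.3]   # gradient for this step
m = Numo::DFloat.zeros(3)             # first moment, as set up by #init_state
v = Numo::DFloat.zeros(3)             # second moment
t = 1                                 # update count

m += (1 - beta1) * (grad - m)                             # EMA of the gradient
v += (1 - beta2) * (grad * grad - v)                      # EMA of the squared gradient
step = alpha * Math.sqrt(1 - beta2**t) / (1 - beta1**t)   # same value as #lr
data -= step * m / (Numo::NMath.sqrt(v) + eps)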