Class: Chainer::Optimizers::MomentumSGDRule

Inherits:
UpdateRule < Object
Defined in:
lib/chainer/optimizers/momentum_sgd.rb

Overview

Update rule for classical momentum stochastic gradient descent (SGD).

Instance Attribute Summary

Attributes inherited from UpdateRule

#state

Instance Method Summary

Methods inherited from UpdateRule

#serialize, #update, #update_core_cpu, #update_core_gpu

Constructor Details

#initialize(parent_hyperparam: nil, lr: nil, mementum: nil) ⇒ MomentumSGDRule

Returns a new instance of MomentumSGDRule.



5
6
7
8
9
10
11
12
13
14
# File 'lib/chainer/optimizers/momentum_sgd.rb', line 5

# Builds a momentum SGD update rule.
#
# When no +parent_hyperparam+ is given, a fresh Hyperparameter with the
# defaults lr = 0.01 and momentum = 0.9 is used as the parent.
# Explicit +lr+ / +momentum+ values are then set on this rule's own
# hyperparameter.
#
# @param parent_hyperparam [Hyperparameter, nil] hyperparameter to inherit from
# @param lr [Numeric, nil] learning rate override
# @param momentum [Numeric, nil] momentum coefficient override
# @param mementum [Numeric, nil] misspelled legacy alias of +momentum+, kept
#   for backward compatibility; ignored when +momentum+ is given
def initialize(parent_hyperparam: nil, lr: nil, momentum: nil, mementum: nil)
  hyperparam = Hyperparameter.new
  hyperparam.instance_variable_set('@lr', 0.01)
  hyperparam.instance_variable_set('@momentum', 0.9)

  super(parent_hyperparam: parent_hyperparam || hyperparam)

  # The legacy keyword previously wrote to `@mementum`, which nothing reads,
  # so a caller-supplied momentum was silently dropped. Route both spellings
  # to the attribute the update rule actually uses.
  momentum ||= mementum
  @hyperparam.instance_variable_set('@lr', lr) if lr
  @hyperparam.instance_variable_set('@momentum', momentum) if momentum
end

Instance Method Details

#init_state(param) ⇒ Object



16
17
18
# File 'lib/chainer/optimizers/momentum_sgd.rb', line 16

# Allocates the velocity buffer used by the momentum update.
#
# @param param [Object] parameter whose +data+ responds to +new_zeros+
#   (presumably a Numo array, giving zeros of the same shape — confirm)
# @return [Object] the freshly allocated buffer, also stored in +@state[:v]+
def init_state(param)
  velocity = param.data.new_zeros
  @state[:v] = velocity
end

#update_core(param) ⇒ Object



20
21
22
23
24
25
26
27
28
# File 'lib/chainer/optimizers/momentum_sgd.rb', line 20

# Applies one classical momentum SGD step to +param+:
#   v    <- momentum * v - lr * grad
#   data <- data + v
# Does nothing when the parameter has no gradient.
#
# @param param [Object] parameter exposing +data+, +data=+ and +grad+
# @return [Object, nil] the updated parameter data, or nil when +grad+ is nil
def update_core(param)
  grad = param.grad
  return if grad.nil?

  # Compound assignment (`v *= x`, `v -= y`) only rebinds a local to a new
  # object; it never mutates @state[:v]. The new velocity must therefore be
  # written back explicitly, otherwise momentum never accumulates across steps.
  v = @state[:v] * @hyperparam.momentum - @hyperparam.lr * grad
  @state[:v] = v
  param.data += v
end