Class: Chainer::Optimizers::MomentumSGDRule
- Inherits:
- UpdateRule (Object → UpdateRule → Chainer::Optimizers::MomentumSGDRule)
- Defined in:
- lib/chainer/optimizers/momentum_sgd.rb
Overview
Update rule for classical momentum SGD.
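As a reading aid, the update performed by #update_core can be written as the pair of equations below. The symbols are notation introduced here, not names from the source: \mu stands for the momentum hyperparameter, \eta for lr, g_t for param.grad, v_t for the velocity kept in the rule's state, and \theta_t for param.data.

```latex
\begin{aligned}
v_{t+1}      &= \mu\, v_t - \eta\, g_t \\
\theta_{t+1} &= \theta_t + v_{t+1}
\end{aligned}
```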
Instance Attribute Summary
Attributes inherited from UpdateRule
Instance Method Summary
- #init_state(param) ⇒ Object
- #initialize(parent_hyperparam: nil, lr: nil, mementum: nil) ⇒ MomentumSGDRule (constructor): A new instance of MomentumSGDRule.
- #update_core(param) ⇒ Object
Methods inherited from UpdateRule
#serialize, #update, #update_core_cpu, #update_core_gpu
Constructor Details
#initialize(parent_hyperparam: nil, lr: nil, mementum: nil) ⇒ MomentumSGDRule
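Returns a new instance of MomentumSGDRule. If parent_hyperparam is nil, a default hyperparameter set with lr = 0.01 and momentum = 0.9 is used as the parent.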
# File 'lib/chainer/optimizers/momentum_sgd.rb', line 5

def initialize(parent_hyperparam: nil, lr: nil, mementum: nil)
  # Defaults, used as the parent only when no parent_hyperparam is given
  hyperparam = Hyperparameter.new
  hyperparam.instance_variable_set('@lr', 0.01)
  hyperparam.instance_variable_set('@momentum', 0.9)

  super(parent_hyperparam: parent_hyperparam || hyperparam)

  # Per-rule overrides. As written, the keyword is spelled `mementum` and its
  # value is stored in '@mementum', not '@momentum'.
  @hyperparam.instance_variable_set('@lr', lr) if lr
  @hyperparam.instance_variable_set('@mementum', mementum) if mementum
end
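The constructor relies on a child hyperparameter object that falls back to a parent for values it does not define itself. The sketch below illustrates that lookup pattern under stated assumptions: SketchHyperparameter and its lookup method are hypothetical stand-ins written for this example, and Chainer's own Hyperparameter class may resolve names differently. It also shows why a value written to '@mementum' would not change what is read back as momentum in such a scheme.

```ruby
# SketchHyperparameter is a hypothetical stand-in, not Chainer's Hyperparameter.
class SketchHyperparameter
  def initialize(parent: nil)
    @parent = parent
  end

  # Look up @name on this object first, then fall back to the parent.
  def lookup(name)
    ivar = "@#{name}"
    return instance_variable_get(ivar) if instance_variable_defined?(ivar)

    @parent&.lookup(name)
  end
end

# Defaults, mirroring the values set in the constructor above.
defaults = SketchHyperparameter.new
defaults.instance_variable_set('@lr', 0.01)
defaults.instance_variable_set('@momentum', 0.9)

child = SketchHyperparameter.new(parent: defaults)
child.instance_variable_set('@lr', 0.05)       # override stored on the child
child.instance_variable_set('@mementum', 0.5)  # misspelled name: a separate variable

lr = child.lookup(:lr)              # found on the child
momentum = child.lookup(:momentum)  # not on the child, so the parent's value is used
```

In this sketch the lr override takes effect, while momentum still resolves to the parent's default because nothing was written under the name '@momentum'.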
Instance Method Details
#init_state(param) ⇒ Object
# File 'lib/chainer/optimizers/momentum_sgd.rb', line 16

def init_state(param)
  @state[:v] = param.data.new_zeros
end
#update_core(param) ⇒ Object
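# File 'lib/chainer/optimizers/momentum_sgd.rb', line 20

def update_core(param)
  grad = param.grad
  return if grad.nil?  # nothing to update without a gradient

  v = @state[:v]
  # `v *= x` is shorthand for `v = v * x`: it rebinds the local `v`.
  # @state[:v] is not reassigned in this method.
  v *= @hyperparam.momentum    # v = momentum * v
  v -= @hyperparam.lr * grad   # v = v - lr * grad
  param.data += v              # param.data = param.data + v
end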
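To make the arithmetic of the update concrete, here is a minimal sketch of classical momentum SGD on plain Ruby arrays. MomentumSketch, its step method, and the sample numbers are hypothetical and exist only for this illustration; the sketch does not use Chainer's classes. Unlike the listing above, it stores the new velocity back explicitly so that it carries over to the next step.

```ruby
# Minimal sketch of classical momentum SGD on plain Ruby arrays.
# MomentumSketch is a hypothetical illustration class, not part of Chainer.
class MomentumSketch
  attr_reader :velocity

  def initialize(size, lr: 0.01, momentum: 0.9)
    @lr = lr
    @momentum = momentum
    @velocity = Array.new(size, 0.0) # plays the role of the velocity state
  end

  # Computes v = momentum * v - lr * grad, stores it, and returns params + v.
  def step(params, grads)
    @velocity = @velocity.zip(grads).map { |v, g| @momentum * v - @lr * g }
    params.zip(@velocity).map { |p, v| p + v }
  end
end

rule = MomentumSketch.new(1, lr: 0.1, momentum: 0.5)
params = [1.0]

# Step 1: velocity starts at 0, so v = -lr * grad, about -0.2; params about 0.8
params = rule.step(params, [2.0])

# Step 2: v = 0.5 * (-0.2) - 0.1 * 2.0, about -0.3; params about 0.5
params = rule.step(params, [2.0])
```

With a constant gradient, the second step moves the parameter further than the first because the stored velocity adds to the new gradient term.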