Class: Chainer::GradientMethod

Inherits:
Optimizer
Defined in:
lib/chainer/gradient_method.rb

Direct Known Subclasses

Optimizers::Adam, Optimizers::MomentumSGD

Instance Attribute Summary

Attributes inherited from Optimizer

#target

Instance Method Summary collapse

Methods inherited from Optimizer

#_call_hook, #add_hook, #serialize

Constructor Details

#initialize ⇒ GradientMethod

Returns a new instance of GradientMethod.



3
4
5
6
# File 'lib/chainer/gradient_method.rb', line 3

# Returns a new instance of GradientMethod.
#
# Runs the base Optimizer initialization via +super+, then creates the
# Hyperparameter container that concrete subclasses (e.g. Adam,
# MomentumSGD) populate with their default settings.
def initialize
  super()
  @hyperparam = Hyperparameter.new
end

Instance Method Details

#call_hooks ⇒ Object



24
25
26
27
28
29
# File 'lib/chainer/gradient_method.rb', line 24

# Invokes every registered hook on the optimizer state.
#
# After each hook runs, gradients are reallocated in case the hook
# cleared them, so subsequent hooks always observe non-nil grads.
def call_hooks
  @hooks.each_value do |hook|
    _call_hook(hook)
    reallocate_cleared_grads
  end
end

#create_update_rule ⇒ Object

Raises:

  • (NotImplementedError)


60
61
62
# File 'lib/chainer/gradient_method.rb', line 60

# Builds the update rule object that #setup attaches to each parameter.
#
# This base class provides no rule; concrete optimizers (e.g.
# Optimizers::Adam, Optimizers::MomentumSGD) must override this method.
#
# @raise [NotImplementedError] always, on this abstract base class
def create_update_rule
  raise NotImplementedError
end

#reallocate_cleared_grads ⇒ Object



15
16
17
18
19
20
21
22
# File 'lib/chainer/gradient_method.rb', line 15

# Re-creates a zero-filled gradient for every initialized parameter
# whose grad has been cleared (set to nil), so hooks and update rules
# can assume grads are always present.
def reallocate_cleared_grads
  @target.namedparams(include_uninit: false) do |(name, param)|
    if param.grad.nil?
      # Select the array backend (e.g. Numo vs GPU module) for this data.
      xm = Chainer.get_array_module(param.data)
      # NOTE(review): this constructs a new xm::NArray from param.data and
      # then zero-fills a same-shaped array with #new_zeros; the
      # intermediate copy looks redundant — confirm whether
      # `param.data.new_zeros` (same shape/dtype, no copy) would suffice.
      param.grad = xm::NArray.[](*param.data).new_zeros
    end
  end
end

#setup(link) ⇒ Object



8
9
10
11
12
13
# File 'lib/chainer/gradient_method.rb', line 8

# Prepares +link+ for optimization: registers it as the target via the
# base class, then attaches a freshly created update rule to each of
# its parameters.
#
# @param link [Chainer::Link] model whose parameters will be optimized
def setup(link)
  super(link)
  link.params { |p| p.update_rule = create_update_rule }
end

#update(lossfun = nil, *args, **kwds) ⇒ Object



31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# File 'lib/chainer/gradient_method.rb', line 31

# Updates the parameters of the target link.
#
# When +lossfun+ is given, it is called with the supplied arguments to
# compute the loss, existing gradients are cleared, and the loss is
# back-propagated before the parameters are updated. When it is omitted,
# the caller is expected to have already populated the gradients.
#
# @param lossfun [#call, nil] callable returning a loss Variable
# @param args positional arguments forwarded to +lossfun+
# @param kwds keyword arguments forwarded to +lossfun+
def update(lossfun=nil, *args, **kwds)
  if lossfun
    # Subclasses may expose #use_cleargrads to switch clearing strategy;
    # respond_to? is the idiomatic check (was methods.include?).
    use_cleargrads = respond_to?(:use_cleargrads) ? self.use_cleargrads : true
    if args.size > 0 && kwds.keys.size > 0
      loss = lossfun.(*args, **kwds)
    elsif args.size > 0
      loss = lossfun.(*args)
    elsif kwds.keys.size > 0
      loss = lossfun.(**kwds)
    else
      # BUG FIX: previously `loss` stayed nil when lossfun took no
      # arguments, so loss.backward() raised NoMethodError.
      loss = lossfun.()
    end

    if use_cleargrads
      @target.cleargrads()
    else
      @target.zerograds()
    end
    loss.backward()
  end

  # Hooks may clear gradients; ensure they exist before and after.
  reallocate_cleared_grads

  call_hooks

  # Advance the step counter consumed by time-dependent update rules.
  @t += 1
  @target.params do |param|
    param.update
  end
end