Class: Chainer::GradientMethod
Instance Attribute Summary
Attributes inherited from Optimizer
#target
Instance Method Summary
Methods inherited from Optimizer
#_call_hook, #add_hook, #serialize
Constructor Details
Returns a new instance of GradientMethod.
3
4
5
6
|
# File 'lib/chainer/gradient_method.rb', line 3
# Builds a gradient-based optimizer.
#
# The parent Optimizer constructor is invoked explicitly with no arguments,
# after which this optimizer receives its own fresh Hyperparameter instance
# (presumably filled in by concrete subclasses — TODO confirm).
def initialize
  super()
  @hyperparam = Hyperparameter.new
end
|
Instance Method Details
#call_hooks ⇒ Object
24
25
26
27
28
29
|
# File 'lib/chainer/gradient_method.rb', line 24
# Runs every registered hook, in the order @hooks returns its values.
#
# The hook list is snapshotted before iterating. After each individual hook,
# reallocate_cleared_grads is invoked so that a hook which cleared gradients
# does not leave nil grads for the next hook or the subsequent update.
def call_hooks
  snapshot = @hooks.values
  snapshot.each do |registered_hook|
    _call_hook(registered_hook)
    reallocate_cleared_grads
  end
end
|
#create_update_rule ⇒ Object
60
61
62
|
# File 'lib/chainer/gradient_method.rb', line 60
# Abstract factory for per-parameter update rules.
#
# Concrete gradient methods must override this; #setup calls it once per
# parameter of the target link and assigns the result to param.update_rule.
#
# @raise [NotImplementedError] always, unless overridden by a subclass.
def create_update_rule
  raise NotImplementedError, "#{self.class}#create_update_rule must be implemented by a subclass"
end
|
#reallocate_cleared_grads ⇒ Object
15
16
17
18
19
20
21
22
|
# File 'lib/chainer/gradient_method.rb', line 15
# Gives every initialized parameter whose gradient is nil a zero-filled one.
#
# Uninitialized parameters are skipped (include_uninit: false). For the rest,
# a nil grad (for example after cleargrads) is replaced by
# param.data.new_zeros, i.e. zeros with the same array class (dtype) and
# shape as the parameter data.
def reallocate_cleared_grads
  @target.namedparams(include_uninit: false) do |(_name, param)|
    # new_zeros keeps dtype, shape and backend (Numo/Cumo). Rebuilding via
    # NArray[*data] copied the data through Ruby arrays and let NArray
    # re-infer the type (e.g. SFloat data could yield a DFloat grad).
    param.grad = param.data.new_zeros if param.grad.nil?
  end
end
|
#setup(link) ⇒ Object
8
9
10
11
12
13
|
# File 'lib/chainer/gradient_method.rb', line 8
# Attaches this optimizer to +link+ via Optimizer#setup, then gives each of
# the link's parameters its own update rule obtained from create_update_rule
# (one call per parameter).
#
# @param link the model whose parameters this optimizer will update.
def setup(link)
  super(link)
  link.params { |param| param.update_rule = create_update_rule }
end
|
#update(lossfun = nil, *args, **kwds) ⇒ Object
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
|
# File 'lib/chainer/gradient_method.rb', line 31
# Performs one optimization step.
#
# When +lossfun+ is given it is called with +args+ / +kwds+ to obtain the
# loss. The target's gradients are then reset — cleargrads by default, or
# zerograds when the optimizer responds to use_cleargrads and it returns a
# falsy value — and loss.backward is run. Afterwards, cleared gradients are
# reallocated, hooks are called, the step counter @t is incremented and every
# parameter of the target applies its update.
#
# @param lossfun [#call, nil] loss function; when nil, the gradients already
#   stored on the parameters are used as-is.
# @param args positional arguments forwarded to +lossfun+.
# @param kwds keyword arguments forwarded to +lossfun+.
def update(lossfun = nil, *args, **kwds)
  if lossfun
    clear_grads = respond_to?(:use_cleargrads) ? self.use_cleargrads : true
    # Always call lossfun (previously a call with no args and no kwds left
    # `loss` nil). Keywords are forwarded only when present, because on some
    # older Rubies splatting an empty hash leaked `{}` as a positional arg.
    loss = kwds.empty? ? lossfun.call(*args) : lossfun.call(*args, **kwds)
    if clear_grads
      @target.cleargrads
    else
      @target.zerograds
    end
    loss.backward
  end

  reallocate_cleared_grads
  call_hooks
  @t += 1
  @target.params(&:update)
end
|