Class: Chainer::UpdateRule
- Inherits:
-
Object
- Object
- Chainer::UpdateRule
show all
- Defined in:
- lib/chainer/optimizer.rb
Instance Attribute Summary collapse
Instance Method Summary
collapse
Constructor Details
#initialize(parent_hyperparam:) ⇒ UpdateRule
Returns a new instance of UpdateRule.
57
58
59
60
61
62
63
|
# File 'lib/chainer/optimizer.rb', line 57
# Builds an update rule with no hooks, no state, updates enabled and a
# zeroed step counter.
#
# @param parent_hyperparam [Object] forwarded as +parent:+ to
#   Chainer::Hyperparameter.new
def initialize(parent_hyperparam:)
  # Hyperparameters of this rule are created with the given parent.
  @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)

  @enabled = true # #update returns immediately when this is falsy
  @hooks = {}     # values are invoked with the param on every #update
  @state = nil    # stays nil until something assigns a state Hash
  @t = 0          # incremented by #update
end
|
Instance Attribute Details
#state ⇒ Object
Returns the value of attribute state.
55
56
57
|
# File 'lib/chainer/optimizer.rb', line 55
# Reader for the @state instance variable.
#
# @return [Object, nil] the current value of @state; the constructor sets it
#   to nil and #serialize may replace it with a Hash
def state
@state
end
|
Instance Method Details
#init_state(param) ⇒ Object
95
96
97
|
# File 'lib/chainer/optimizer.rb', line 95
# Placeholder in the base class: always raises NotImplementedError.
# #serialize calls this on a copy of the rule when restoring with no state.
# NOTE(review): presumably subclasses override this to populate @state for
# +param+ — confirm against the concrete update rules.
#
# @param param [Object] not used by this base implementation
# @raise [NotImplementedError] always
def init_state(param)
raise NotImplementedError
end
|
#serialize(serializer) ⇒ Object
Serializes the update rule state. Note that this method only saves/loads the state of the update rule itself. The parameters of the target link are not saved/loaded by this method, so you need to serialize the target link separately if you want to fully recover the training state, including parameters.
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
|
# File 'lib/chainer/optimizer.rb', line 107
# Passes every entry of @state through +serializer+ and stores the result
# back under the same key.
#
# When @state is nil and +serializer+ is a Chainer::Deserializer, the state
# keys are first discovered by running #init_state on a copy of this rule
# with a throwaway one-element variable, then each key is loaded with a nil
# placeholder value. When @state is nil and +serializer+ is anything else,
# nothing happens.
#
# @param serializer [#call] invoked as +serializer.(key_string, value)+
# @return [Hash, Array, nil] the state Hash when a state already existed,
#   the Array of discovered keys after a restore, nil otherwise
def serialize(serializer)
  unless @state.nil?
    return @state.each do |name, current|
      @state[name] = serializer.call(name.to_s, current)
    end
  end

  return unless serializer.is_a?(Chainer::Deserializer)

  @state = {}
  # dup is a shallow copy, so the copy's @state is this very Hash; whatever
  # init_state stores into it becomes visible here.
  # NOTE(review): assumes init_state mutates @state in place rather than
  # reassigning it — confirm against the concrete update rules.
  probe = dup
  xm = Chainer::Device.default.xm
  dummy = xm::SFloat.new(1)
  probe.init_state(Chainer::Variable.new(dummy, grad: dummy))

  # Iterate over a snapshot of the keys while overwriting the values.
  @state.keys.each do |name|
    @state[name] = serializer.call(name.to_s, nil)
  end
end
|
#update(param) ⇒ Object
65
66
67
68
69
70
71
72
73
74
75
76
|
# File 'lib/chainer/optimizer.rb', line 65
# Runs one update step for +param+.
#
# Does nothing (and returns nil) while the rule is disabled. Otherwise the
# step counter is bumped, #prepare is run for params whose +data+ is not
# nil, every registered hook is called with the param, and finally
# #update_core performs the actual update.
#
# @param param [#data] the parameter to update
# @return [Object, nil] the result of #update_core, or nil when disabled
def update(param)
  if @enabled
    @t += 1
    prepare(param) unless param.data.nil?
    # Iterate over a snapshot of the hooks, in insertion order.
    @hooks.values.each { |callback| callback.call(param) }
    update_core(param)
  end
end
|
#update_core(param) ⇒ Object
78
79
80
81
82
83
84
85
|
# File 'lib/chainer/optimizer.rb', line 78
# Dispatches the update to the GPU or CPU implementation.
#
# The GPU path is taken when Chainer.get_array_module(param) equals Cumo;
# every other array module goes to the CPU path.
#
# @param param [Object] the parameter being updated
# @return [Object] whatever the selected implementation returns
def update_core(param)
  backend = Chainer.get_array_module(param)
  backend == Cumo ? update_core_gpu(param) : update_core_cpu(param)
end
|
#update_core_cpu ⇒ Object
87
88
89
|
# File 'lib/chainer/optimizer.rb', line 87
# CPU update step; the base class has no implementation.
#
# #update_core invokes this as +update_core_cpu(param)+, so the method must
# accept the parameter. Otherwise callers get an ArgumentError instead of
# the intended NotImplementedError. The argument is optional so that
# existing zero-argument calls keep working.
#
# @param _param [Object, nil] the parameter being updated (unused here)
# @raise [NotImplementedError] always
def update_core_cpu(_param = nil)
  raise NotImplementedError
end
|
#update_core_gpu ⇒ Object
91
92
93
|
# File 'lib/chainer/optimizer.rb', line 91
# GPU update step; the base class has no implementation.
#
# #update_core invokes this as +update_core_gpu(param)+, so the method must
# accept the parameter. Otherwise callers get an ArgumentError instead of
# the intended NotImplementedError. The argument is optional so that
# existing zero-argument calls keep working.
#
# @param _param [Object, nil] the parameter being updated (unused here)
# @raise [NotImplementedError] always
def update_core_gpu(_param = nil)
  raise NotImplementedError
end
|