Class: Chainer::Training::Extensions::ExponentialShift
- Inherits:
- Chainer::Training::Extension
- Object
- Chainer::Training::Extension
- Chainer::Training::Extensions::ExponentialShift
- Defined in:
- lib/chainer/training/extensions/exponential_shift.rb
Overview
Trainer extension to exponentially shift an optimizer attribute.
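This extension exponentially increases or decreases the specified attribute of the optimizer: on its t-th invocation it sets the attribute to init * rate ** t, so a rate below 1 decays the value and a rate above 1 grows it. The typical use case is exponential decay of the learning rate. By default, the extension is also called before the training loop starts.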
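The schedule described above can be illustrated without the library. This is a minimal sketch, not the extension itself: `shifted_value` is a hypothetical helper that only restates the formula `init * rate ** t` used in `#call`, and the starting value 0.1 and rate 0.5 are arbitrary choices.

```ruby
# Hypothetical helper (not part of red-chainer): the attribute value
# after t invocations of the extension, ignoring any target.
def shifted_value(init, rate, t)
  init * (rate**t)
end

# A learning rate that starts at 0.1 and is halved on every invocation.
schedule = (1..4).map { |t| shifted_value(0.1, 0.5, t) }
puts schedule.inspect
```

A rate of exactly 1 leaves the value unchanged, and a rate above 1 makes the same helper produce a growing sequence.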
Constant Summary
Constants inherited from Chainer::Training::Extension
Chainer::Training::Extension::PRIORITY_EDITOR, Chainer::Training::Extension::PRIORITY_READER, Chainer::Training::Extension::PRIORITY_WRITER
Instance Attribute Summary
- #last_value ⇒ Object (readonly): Returns the value of attribute last_value.
Attributes inherited from Chainer::Training::Extension
Instance Method Summary
- #call(trainer) ⇒ Object
- #init(trainer) ⇒ Object
- #initialize(attr, rate, init: nil, target: nil, optimizer: nil) ⇒ ExponentialShift (constructor): A new instance of ExponentialShift.
- #serialize(serializer) ⇒ Object
Methods inherited from Chainer::Training::Extension
Constructor Details
#initialize(attr, rate, init: nil, target: nil, optimizer: nil) ⇒ ExponentialShift
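Returns a new instance of ExponentialShift. attr is the name of the optimizer attribute to shift and rate is the per-invocation multiplier; a negative rate raises an error. init is the starting value (when nil, the optimizer's current value is read at initialization time), and target, if given, is the value at which the shift stops.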
# File 'lib/chainer/training/extensions/exponential_shift.rb', line 17

def initialize(attr, rate, init: nil, target: nil, optimizer: nil)
  @attr = attr
  raise 'ExponentialShift does not support negative rate' if rate < 0
  @rate = rate
  @init = init
  @target = target
  @optimizer = optimizer
  @t = 0
  @last_value = nil
end
Instance Attribute Details
#last_value ⇒ Object (readonly)
Returns the value of attribute last_value.
# File 'lib/chainer/training/extensions/exponential_shift.rb', line 10

def last_value
  @last_value
end
Instance Method Details
#call(trainer) ⇒ Object
# File 'lib/chainer/training/extensions/exponential_shift.rb', line 38

def call(trainer)
  @t += 1

  optimizer = get_optimizer(trainer)
  value = @init * (@rate ** @t)
  if @target
    if @rate > 1
      if value / @target > 1
        value = @target
      end
    else
      if value / @target < 1
        value = @target
      end
    end
  end
  update_value(optimizer, value)
end
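The target handling in `#call` is easier to follow in isolation. The sketch below restates only that branch as a free function; `clamp_to_target` is a hypothetical name, and the values 0.1, 0.5 and 0.01 are arbitrary.

```ruby
# Hypothetical stand-alone restatement of the target check in #call.
def clamp_to_target(value, rate, target)
  return value unless target

  if rate > 1
    value / target > 1 ? target : value # growing: cap at the target
  else
    value / target < 1 ? target : value # decaying: stop at the target
  end
end

init = 0.1
rate = 0.5
target = 0.01
values = (1..5).map { |t| clamp_to_target(init * (rate**t), rate, target) }
puts values.inspect
```

Because the check compares a ratio against 1, Integer operands would go through Ruby's truncating integer division; the sketch uses Floats to avoid that.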
#init(trainer) ⇒ Object
# File 'lib/chainer/training/extensions/exponential_shift.rb', line 28

def init(trainer)
  optimizer = get_optimizer(trainer)
  @init = optimizer.send(@attr) if @init.nil?
  if @last_value.nil?
    update_value(optimizer, @init)
  else
    update_value(optimizer, @last_value)
  end
end
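The choice `#init` makes between the configured initial value, the optimizer's current value, and a previously recorded `last_value` can be sketched on its own. `FakeOptimizer` and `starting_value` are hypothetical stand-ins introduced for illustration; they are not part of red-chainer.

```ruby
# Stand-in for an optimizer; only the attribute reader matters here.
FakeOptimizer = Struct.new(:lr)

# Hypothetical mirror of the decision in #init: which value is written
# to the optimizer before training starts or resumes.
def starting_value(optimizer, attr, init, last_value)
  init = optimizer.send(attr) if init.nil? # fall back to the optimizer's value
  last_value.nil? ? init : last_value      # a recorded value takes precedence
end

opt = FakeOptimizer.new(0.01)
fresh    = starting_value(opt, 'lr', nil, nil)   # no init given, nothing recorded
explicit = starting_value(opt, 'lr', 0.1, nil)   # init given, nothing recorded
resumed  = starting_value(opt, 'lr', 0.1, 0.025) # value recorded by an earlier run
puts [fresh, explicit, resumed].inspect
```

The last case corresponds to resuming from a snapshot: the shifted value from the earlier run is restored instead of restarting from the initial value.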
#serialize(serializer) ⇒ Object
# File 'lib/chainer/training/extensions/exponential_shift.rb', line 57

def serialize(serializer)
  @t = serializer.('t', @t)
  @last_value = serializer.('last_value', @last_value)
  if Chainer.array?(@last_value)
    @last_value = @last_value[0]
  end
end
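The `serializer.(...)` form in `#serialize` is Ruby shorthand for `serializer.call(...)`. The sketch below shows the calling pattern with a lambda standing in for a deserializer. It assumes the serializer returns the value the caller should keep; the `saved` hash and the `loader` lambda are hypothetical, and the `is_a?(Array)` test is only a stand-in for the `Chainer.array?` check.

```ruby
# Hypothetical in-memory deserializer: returns the stored value for a key,
# or the current value when nothing was saved under that key.
saved = { 't' => 3, 'last_value' => [0.0125] }
loader = ->(key, current) { saved.fetch(key, current) }

t = loader.('t', 0) # identical to loader.call('t', 0)
last_value = loader.('last_value', nil)

# Stand-in for the Chainer.array? check: unwrap a one-element container
# so that last_value ends up as a plain number.
last_value = last_value[0] if last_value.is_a?(Array)

puts [t, last_value].inspect
```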