Class: Chainer::Training::Extensions::ExponentialShift

Inherits:
Chainer::Training::Extension
Defined in:
lib/chainer/training/extensions/exponential_shift.rb

Overview

Trainer extension to exponentially shift an optimizer attribute.

This extension exponentially increases or decreases the specified attribute of the optimizer. The typical use case is exponential decay of the learning rate. By default, the extension is also invoked once before the training loop starts (see #init), so the attribute holds its initial value from the first iteration.
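As a rough illustration of the typical use case, the stand-alone Ruby sketch below decays a learning-rate attribute `lr` by a factor of 0.5 per shift. `DecayOptimizer` is a hypothetical stand-in, not a red-chainer class, and the loop only mimics what repeated calls to the extension do; it does not use the real trainer.

```ruby
# Hypothetical stand-in for an optimizer that exposes an `lr` attribute.
DecayOptimizer = Struct.new(:lr)

optimizer = DecayOptimizer.new(0.1)
init = optimizer.lr # read once, as #init does when init: is nil
rate = 0.5          # rate < 1 => exponential decay

# Each step applies the same formula as ExponentialShift#call: init * rate**t.
history = (1..4).map do |t|
  optimizer.lr = init * rate**t
end

p history # prints [0.05, 0.025, 0.0125, 0.00625]
```

Note that the value is always recomputed from `init` and the step count rather than multiplied in place, so the schedule does not accumulate drift from repeated updates.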

Constant Summary

Constants inherited from Chainer::Training::Extension

Chainer::Training::Extension::PRIORITY_EDITOR, Chainer::Training::Extension::PRIORITY_READER, Chainer::Training::Extension::PRIORITY_WRITER

Instance Attribute Summary

Attributes inherited from Chainer::Training::Extension

#name, #priority, #trigger

Instance Method Summary

Methods inherited from Chainer::Training::Extension

#default_name

Constructor Details

#initialize(attr, rate, init: nil, target: nil, optimizer: nil) ⇒ ExponentialShift

Returns a new instance of ExponentialShift.

Parameters:

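  • attr (String)

    Name of the optimizer attribute to shift.

  • rate (Float)

    Rate of the exponential shift. After t calls the value is init * rate**t, so a rate below 1 decays the attribute and a rate above 1 grows it. A negative rate raises an error.

  • init (Float) (defaults to: nil)

    Initial value of the attribute. If nil, the current value of the attribute is read from the optimizer in #init.

  • target (Float) (defaults to: nil)

    Target value of the attribute. If given, the attribute is held at this value once the shift passes it.

  • optimizer (Chainer::Optimizer) (defaults to: nil)

    Target optimizer whose attribute is adjusted. If nil, the optimizer is obtained from the trainer when the extension runs.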
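Taken together, the parameters describe the schedule below, where $t$ is the number of times #call has run. This restates the arithmetic of #call as a formula; the clamping cases apply only when target is given.

```latex
v_t = \mathrm{init}\cdot\mathrm{rate}^{t}, \qquad t = 1, 2, \ldots
\qquad
v_t \leftarrow
\begin{cases}
\mathrm{target} & \text{if } \mathrm{rate} > 1 \text{ and } v_t/\mathrm{target} > 1,\\
\mathrm{target} & \text{if } \mathrm{rate} \le 1 \text{ and } v_t/\mathrm{target} < 1,\\
v_t & \text{otherwise.}
\end{cases}
```

For positive init and target this reduces to $\min(v_t, \mathrm{target})$ when rate > 1 and $\max(v_t, \mathrm{target})$ when rate ≤ 1; using the ratio instead of a plain comparison lets the same rule work when both values are negative.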



# File 'lib/chainer/training/extensions/exponential_shift.rb', line 17

def initialize(attr, rate, init: nil, target: nil, optimizer: nil)
  @attr = attr
  raise 'ExponentialShift does not support negative rate' if rate < 0
  @rate = rate
  @init = init
  @target = target
  @optimizer = optimizer
  @t = 0
  @last_value = nil
end
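Because the guard uses `raise` with a bare string, a negative rate surfaces as a `RuntimeError` rather than an `ArgumentError`. The sketch below reproduces only that check in a hypothetical `build_shift_params` helper (not part of red-chainer) to show how a caller could detect it.

```ruby
# Hypothetical helper reproducing only the constructor's rate check.
def build_shift_params(attr, rate, init: nil, target: nil)
  raise 'ExponentialShift does not support negative rate' if rate < 0
  { attr: attr, rate: rate, init: init, target: target }
end

params = build_shift_params('lr', 0.9, target: 1e-4)

error = begin
  build_shift_params('lr', -0.5)
  nil
rescue RuntimeError => e
  e # a bare-string raise produces a RuntimeError
end

puts error.message # prints "ExponentialShift does not support negative rate"
```

A rate of exactly 0 passes the check (`0 < 0` is false); with such a rate the attribute becomes 0 after the first call.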

Instance Attribute Details

#last_value ⇒ Object (readonly)

Returns the value of attribute last_value.



# File 'lib/chainer/training/extensions/exponential_shift.rb', line 10

def last_value
  @last_value
end
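The reader shown above is the method Ruby generates for `attr_reader :last_value`. The hypothetical `ShiftState` class below (not part of red-chainer) illustrates that the attribute can be read but has no public writer.

```ruby
# Hypothetical class: attr_reader generates a reader equivalent to
# `def last_value; @last_value; end`, and no `last_value=` writer.
class ShiftState
  attr_reader :last_value

  def initialize
    @last_value = nil # nil until a value has been recorded
  end

  # Hypothetical method used only in this sketch to record a value.
  def record(value)
    @last_value = value
  end
end

state = ShiftState.new
p state.last_value                 # prints nil
state.record(0.05)
p state.last_value                 # prints 0.05
p state.respond_to?(:last_value=)  # prints false
```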

Instance Method Details

#call(trainer) ⇒ Object

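Advances the internal step counter t by one and sets the attribute to init * rate**t. When target is given, the value is held at target once the shift passes it; the check compares the ratio value / target with 1, so negative values are handled as well.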

# File 'lib/chainer/training/extensions/exponential_shift.rb', line 38

def call(trainer)
  @t += 1

  optimizer = get_optimizer(trainer)
  value = @init * (@rate ** @t)
  if @target
    if @rate > 1
      if value / @target > 1
        value = @target
      end
    else
      if value / @target < 1
        value = @target
      end
    end
  end
  update_value(optimizer, value)
end
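The value computation in #call can be isolated as a pure function. The sketch below uses a hypothetical `shifted_value` helper (not part of red-chainer) that mirrors the same arithmetic and clamping, leaving out the optimizer update, so the effect of target is easy to check.

```ruby
# Hypothetical pure function mirroring the arithmetic of ExponentialShift#call.
# t is the step count after incrementing (1 on the first call).
def shifted_value(init, rate, t, target = nil)
  value = init * rate**t
  if target
    if rate > 1
      # Growing schedule: hold at target once value passes it.
      value = target if value / target > 1
    else
      # Decaying (or constant) schedule: hold at target once value passes it.
      value = target if value / target < 1
    end
  end
  value
end

# Decay from 1.0 by half each step, floored at 0.2.
decay = (1..4).map { |t| shifted_value(1.0, 0.5, t, 0.2) }
p decay  # prints [0.5, 0.25, 0.2, 0.2]

# Growth by a factor of 2, capped at 5.0.
growth = (1..4).map { |t| shifted_value(1.0, 2.0, t, 5.0) }
p growth # prints [2.0, 4.0, 5.0, 5.0]
```

With negative values, for example init = -1.0 and target = -0.2, the ratio test still floors the magnitude at 0.2, which a plain `value < target` comparison would get wrong.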

#init(trainer) ⇒ Object

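Prepares the optimizer before training. If init was not given, it is read from the optimizer's current attribute value. The attribute is then set to init, or to last_value if one has already been recorded (for example, restored by #serialize when resuming).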

# File 'lib/chainer/training/extensions/exponential_shift.rb', line 28

def init(trainer)
  optimizer = get_optimizer(trainer)
  @init = optimizer.send(@attr) if @init.nil?
  if @last_value.nil?
    update_value(optimizer, @init)
  else
    update_value(optimizer, @last_value)
  end
end
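The branching in #init can be exercised with a stand-in. In the sketch below, `InitOptimizer` and `InitSketch` are hypothetical, and because `update_value` is not shown on this page, it is assumed to assign through the attribute writer and record the value in `@last_value`.

```ruby
# Hypothetical stand-in optimizer with an `lr` attribute.
InitOptimizer = Struct.new(:lr)

# Hypothetical class reproducing the logic of ExponentialShift#init.
class InitSketch
  attr_reader :init_value, :last_value

  def initialize(attr, init: nil, last_value: nil)
    @attr = attr
    @init_value = init # plays the role of @init in the real class
    @last_value = last_value
  end

  def init(optimizer)
    # Read the current attribute value when no init: was given.
    @init_value = optimizer.send(@attr) if @init_value.nil?
    update_value(optimizer, @last_value.nil? ? @init_value : @last_value)
  end

  private

  # ASSUMED behavior: write via the attribute setter and remember the value.
  def update_value(optimizer, value)
    optimizer.send("#{@attr}=", value)
    @last_value = value
  end
end

fresh = InitOptimizer.new(0.1)
InitSketch.new('lr').init(fresh)
p fresh.lr   # prints 0.1 (init taken from the optimizer)

resumed = InitOptimizer.new(0.1)
InitSketch.new('lr', last_value: 0.025).init(resumed)
p resumed.lr # prints 0.025 (a recorded last_value takes precedence)
```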

#serialize(serializer) ⇒ Object

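Saves or restores the step counter t and last_value with the given serializer. If the restored last_value is an array, its first element is used.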

# File 'lib/chainer/training/extensions/exponential_shift.rb', line 57

def serialize(serializer)
  @t = serializer.('t', @t)
  @last_value = serializer.('last_value', @last_value)
  if Chainer.array?(@last_value)
    @last_value = @last_value[0]
  end
end
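The `serializer.(key, value)` syntax is Ruby shorthand for `serializer.call(key, value)`, so any callable works in a sketch. Below, two hypothetical lambdas act as saver and loader over a Hash, and `SerializeSketch` reproduces #serialize. The real code detects arrays with `Chainer.array?`; a plain Ruby Array stands in here as an assumption, presumably reflecting serializers that return a scalar wrapped in an array.

```ruby
# Hypothetical serializers: any object responding to #call(key, value) works.
store  = {}
saver  = ->(key, value) { store[key] = value; value } # records and echoes
loader = ->(key, value) { store.fetch(key, value) }   # returns stored value

# Hypothetical state holder reproducing ExponentialShift#serialize.
class SerializeSketch
  attr_reader :t, :last_value

  def initialize(t = 0, last_value = nil)
    @t = t
    @last_value = last_value
  end

  def serialize(serializer)
    @t = serializer.('t', @t)
    @last_value = serializer.('last_value', @last_value)
    # Stand-in for Chainer.array?: unwrap a one-element array.
    @last_value = @last_value[0] if @last_value.is_a?(Array)
  end
end

SerializeSketch.new(3, 0.0125).serialize(saver)
restored = SerializeSketch.new
restored.serialize(loader)
p [restored.t, restored.last_value] # prints [3, 0.0125]

store['last_value'] = [0.0125]      # value stored as a one-element array
wrapped = SerializeSketch.new
wrapped.serialize(loader)
p wrapped.last_value                # prints 0.0125
```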