Class: Chainer::Training::Triggers::IntervalTrigger

Inherits:
Object
  • Object
show all
Defined in:
lib/chainer/training/triggers/interval.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(period, unit) ⇒ IntervalTrigger

Returns a new instance of IntervalTrigger.



7
8
9
10
11
12
13
14
# File 'lib/chainer/training/triggers/interval.rb', line 7

# Builds a trigger that fires once every +period+ units.
#
# @param period interval length; used as the divisor in #call
# @param unit 'epoch' to measure the interval in epochs; any other value
#   measures it in iterations
def initialize(period, unit)
  @period, @unit = period, unit
  @count = 0

  # Positions recorded at the previous #call; the first call compares against zero.
  @previous_iteration = 0
  @previous_epoch_detail = 0.0
end

Instance Attribute Details

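#count ⇒ Object (readonly)

Returns the number of whole periods elapsed (epoch_detail divided by period, rounded down), as computed by the most recent call when the unit is 'epoch'. It starts at 0 and is not updated for iteration-based triggers.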



5
6
7
# File 'lib/chainer/training/triggers/interval.rb', line 5

# @return whole periods elapsed, as of the last epoch-based #call (0 initially)
def count; @count; end

#period ⇒ Object (readonly)

Returns the length of the interval between firings, measured in units of `unit`.



5
6
7
# File 'lib/chainer/training/triggers/interval.rb', line 5

# @return interval length between firings, in units of #unit
def period; @period; end

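#unit ⇒ Object (readonly)

Returns the unit in which the period is measured: 'epoch' selects epoch-based triggering, and any other value is treated as iteration-based.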



5
6
7
# File 'lib/chainer/training/triggers/interval.rb', line 5

# @return 'epoch' for epoch-based triggering; any other value means iterations
def unit; @unit; end

Instance Method Details

#call(trainer) ⇒ boolean

Decides whether the extension should be called on this iteration.

Parameters:

  • trainer (Chainer::Trainer)

    Trainer object that this trigger is associated with. The updater associated with this trainer is used to determine if the trigger should fire.

Returns:

  • (boolean)

    True if the corresponding extension should be invoked in this iteration.



21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# File 'lib/chainer/training/triggers/interval.rb', line 21

# Decides whether the extension should be called on this iteration.
#
# The trigger fires whenever the current position crosses a multiple of
# +period+ since the previous call. When +unit+ is 'epoch' the position is
# +updater.epoch_detail+; otherwise it is +updater.iteration+.
#
# A negative stored previous value is treated as unknown. The epoch branch
# then falls back to +updater.previous_epoch_detail+. The iteration branch
# assumes the previous call happened one iteration earlier.
# NOTE(review): negative values presumably come from restored state — confirm
# against the serializer in use.
#
# @param trainer [Chainer::Trainer] trainer whose updater is inspected
# @return [Boolean] true if the corresponding extension should be invoked
def call(trainer)
  updater = trainer.updater

  if @unit == 'epoch'
    epoch_detail = updater.epoch_detail
    previous_epoch_detail = @previous_epoch_detail
    previous_epoch_detail = updater.previous_epoch_detail if previous_epoch_detail < 0

    # Numeric#div already floors and returns an Integer; no extra .floor needed.
    current_period = epoch_detail.div(@period)
    @count = current_period
    fire = previous_epoch_detail.div(@period) != current_period
  else
    iteration = updater.iteration
    previous_iteration = @previous_iteration
    previous_iteration = iteration - 1 if previous_iteration < 0

    fire = previous_iteration.div(@period) != iteration.div(@period)
  end

  # Save current values. Iteration-based triggers only need +iteration+, so an
  # updater without +epoch_detail+ must not make this method raise.
  @previous_iteration = updater.iteration
  @previous_epoch_detail = updater.epoch_detail if updater.respond_to?(:epoch_detail)

  fire
end

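#serialize(serializer) ⇒ Object

Saves or restores the trigger's previous iteration and previous epoch detail through `serializer`. The serializer is called as `serializer.(key, value)`, and its return values replace the stored state.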



50
51
52
53
# File 'lib/chainer/training/triggers/interval.rb', line 50

# Round-trips the trigger's "previous position" state through +serializer+.
#
# +serializer+ is invoked with a key and the current value; whatever it returns
# becomes the new stored value. The iteration is handled before the epoch detail.
#
# @param serializer callable taking (key, value)
# @return the value the serializer produced for 'previous_epoch_detail'
def serialize(serializer)
  restored_iteration = serializer.call('previous_iteration', @previous_iteration)
  @previous_iteration = restored_iteration

  restored_epoch_detail = serializer.call('previous_epoch_detail', @previous_epoch_detail)
  @previous_epoch_detail = restored_epoch_detail
end