Class: Chainer::Training::StandardUpdater

Inherits:
Updater
  • Object
show all
Defined in:
lib/chainer/training/standard_updater.rb

Instance Attribute Summary

Instance Method Summary

Methods inherited from Updater

#bind, #connect_trainer

Constructor Details

#initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil) ⇒ StandardUpdater

Returns a new instance of StandardUpdater.



6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# File 'lib/chainer/training/standard_updater.rb', line 6

# Creates an updater that drives one or more optimizers from one or more
# dataset iterators.
#
# @param iterator [Hash, Object] a single iterator, or a Hash of iterators
#   keyed by name. Any non-Hash value is registered under +:main+.
# @param optimizer [Hash, Object] a single optimizer, or a Hash of optimizers
#   keyed by name. Any non-Hash value is registered under +:main+.
# @param converter [#call, nil] turns a batch into model inputs; invoked as
#   +converter.call(batch, device: device)+. Defaults to
#   +Dataset::Convert.method(:concat_examples)+.
# @param device [Object, nil] forwarded to the converter as +device:+.
# @param loss_func [#call, nil] function handed to the optimizer's update;
#   when nil, the main optimizer's +target+ is used instead.
def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
  # Both arguments are normalized the same way: anything that is not already
  # a name => object Hash gets wrapped under :main. Checking for Hash (rather
  # than for a Dataset::Iterator subclass) means duck-typed iterators are
  # wrapped too instead of being stored bare, which would break the
  # @iterators[:main] lookups elsewhere in the class.
  @iterators = iterator.kind_of?(Hash) ? iterator : { main: iterator }
  @optimizers = optimizer.kind_of?(Hash) ? optimizer : { main: optimizer }

  @converter = converter || Dataset::Convert.method(:concat_examples)
  @loss_func = loss_func
  @device = device
  @iteration = 0
end

Instance Attribute Details

#iteration ⇒ Object

Returns the value of attribute iteration.



4
5
6
# File 'lib/chainer/training/standard_updater.rb', line 4

# Number of update steps performed so far. It starts at 0 in the
# constructor, grows by one per #update, and is restored by #serialize.
#
# @return [Object] the current iteration counter.
def iteration
  count = @iteration
  count
end

Instance Method Details

#epoch ⇒ Object



36
37
38
# File 'lib/chainer/training/standard_updater.rb', line 36

# Current epoch, as reported by the +:main+ iterator.
#
# @return [Object] whatever the main iterator's +epoch+ returns.
def epoch
  main_iterator = @iterators[:main]
  main_iterator.epoch
end

#epoch_detail ⇒ Object



40
41
42
# File 'lib/chainer/training/standard_updater.rb', line 40

# Fine-grained epoch progress, as reported by the +:main+ iterator.
#
# @return [Object] whatever the main iterator's +epoch_detail+ returns.
def epoch_detail
  main_iterator = @iterators[:main]
  main_iterator.epoch_detail
end

#finalize ⇒ Object



60
61
62
63
64
# File 'lib/chainer/training/standard_updater.rb', line 60

# Calls +finalize+ on every registered iterator, in registration order.
#
# @return [Hash] the iterator table itself (the result of the iteration).
def finalize
  @iterators.each_value(&:finalize)
end

#get_all_optimizers ⇒ Object



27
28
29
# File 'lib/chainer/training/standard_updater.rb', line 27

# Returns every registered optimizer keyed by name.
#
# A shallow copy is returned because Hash#to_h on a plain Hash returns the
# receiver itself. Returning it directly would let callers add or remove
# entries in the updater's own optimizer table. The optimizer objects
# themselves are shared, not copied.
#
# @return [Hash] a new name => optimizer Hash.
def get_all_optimizers
  @optimizers.to_h.dup
end

#get_optimizer(name) ⇒ Object



23
24
25
# File 'lib/chainer/training/standard_updater.rb', line 23

# Looks up a registered optimizer by name.
#
# @param name [Object] key the optimizer was registered under (e.g. +:main+).
# @return [Object, nil] the optimizer, or the Hash default (nil for the
#   Hash built in the constructor) when +name+ is not registered.
def get_optimizer(name)
  table = @optimizers
  table[name]
end

#serialize(serializer) ⇒ Object



66
67
68
69
70
71
72
73
74
75
76
# File 'lib/chainer/training/standard_updater.rb', line 66

# Saves or restores the updater's state through +serializer+.
#
# Each iterator is serialized under "iterator:<name>". Each optimizer is
# serialized under "optimizer:<name>", and its target under "model:<name>".
# Finally the iteration counter is passed through +serializer.call+ and
# replaced by whatever that call returns.
#
# @param serializer [Object] must respond to +[]+ (child lookup) and
#   +call(key, value)+.
# @return [Object] the new iteration counter.
def serialize(serializer)
  @iterators.each_pair do |key, iter|
    iter.serialize(serializer["iterator:#{key}"])
  end

  @optimizers.each_pair do |key, opt|
    opt.serialize(serializer["optimizer:#{key}"])
    opt.target.serialize(serializer["model:#{key}"])
  end

  @iteration = serializer.call('iteration', @iteration)
end

#update ⇒ Object



31
32
33
34
# File 'lib/chainer/training/standard_updater.rb', line 31

# Runs a single training step (#update_core) and then advances the
# iteration counter by one.
#
# @return [Object] the incremented iteration counter.
def update
  update_core
  @iteration = @iteration + 1
end

#update_core ⇒ Object



44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# File 'lib/chainer/training/standard_updater.rb', line 44

# Performs one optimization step with the +:main+ iterator and optimizer.
#
# The next batch is passed through the converter (with +device:+). The
# result is then handed to the optimizer's +update+: an Array is splatted
# positionally, a Hash is splatted as keywords, and anything else is passed
# as a single argument. The loss function is +@loss_func+ when set,
# otherwise the main optimizer's +target+.
#
# @return [Object] whatever the optimizer's +update+ returns.
def update_core
  raw_batch = @iterators[:main].next
  inputs = @converter.call(raw_batch, device: @device)

  main_optimizer = @optimizers[:main]
  target_func = @loss_func || main_optimizer.target

  case inputs
  when Array then main_optimizer.update(target_func, *inputs)
  when Hash  then main_optimizer.update(target_func, **inputs)
  else            main_optimizer.update(target_func, inputs)
  end
end