Class: Chainer::Training::StandardUpdater
Defined in: lib/chainer/training/standard_updater.rb
Instance Attribute Summary
- #iteration ⇒ Object — Returns the value of attribute iteration.
Instance Method Summary
- #epoch ⇒ Object
- #epoch_detail ⇒ Object
- #finalize ⇒ Object
- #get_all_optimizers ⇒ Object
- #get_optimizer(name) ⇒ Object
- #initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil) ⇒ StandardUpdater (constructor) — Returns a new instance of StandardUpdater.
- #serialize(serializer) ⇒ Object
- #update ⇒ Object
- #update_core ⇒ Object
Methods inherited from Updater
Constructor Details
#initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil) ⇒ StandardUpdater
Returns a new instance of StandardUpdater.
Source lines 6–21:
# File 'lib/chainer/training/standard_updater.rb', line 6
# Builds an updater from one or more iterators and optimizers.
#
# Both +iterator+ and +optimizer+ may be given as a single object or as a
# Hash of name => object. A single object is registered under :main, which
# is the name #epoch, #epoch_detail and #update_core read.
#
# Iterators used to be wrapped only when they were a Dataset::Iterator. Any
# other non-Hash iterator (e.g. a duck-typed one) was stored as the table
# itself and typically failed later on @iterators[:main]. Iterators are now
# normalized exactly like optimizers: anything that is not a Hash is wrapped.
#
# @param iterator [Object, Hash] the training iterator, or a name => iterator Hash
# @param optimizer [Object, Hash] the optimizer, or a name => optimizer Hash
# @param converter [#call, nil] invoked as converter.call(batch, device: device);
#   defaults to Dataset::Convert.method(:concat_examples)
# @param device [Object, nil] forwarded to the converter as +device:+
# @param loss_func [Object, nil] when set, passed to the optimizer's #update
#   in place of the optimizer's target
def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
  @iterators = iterator.kind_of?(Hash) ? iterator : { main: iterator }
  @optimizers = optimizer.kind_of?(Hash) ? optimizer : { main: optimizer }
  @converter = converter || Dataset::Convert.method(:concat_examples)
  @loss_func = loss_func
  @device = device
  @iteration = 0
end
Instance Attribute Details
#iteration ⇒ Object
Returns the value of attribute iteration.
Source lines 4–6:
# File 'lib/chainer/training/standard_updater.rb', line 4
# Number of update steps performed so far.
#
# Set to 0 by the constructor, incremented by #update after each successful
# #update_core call, and overwritten by #serialize.
#
# @return [Object] the current iteration counter
def iteration
  @iteration
end
Instance Method Details
#epoch ⇒ Object
Source lines 36–38:
# File 'lib/chainer/training/standard_updater.rb', line 36
# Current epoch, as reported by the iterator registered under :main.
#
# @return [Object] whatever the main iterator's #epoch returns
def epoch
  main_iterator = @iterators[:main]
  main_iterator.epoch
end
#epoch_detail ⇒ Object
Source lines 40–42:
# File 'lib/chainer/training/standard_updater.rb', line 40
# Fine-grained epoch progress, as reported by the iterator registered under :main.
#
# @return [Object] whatever the main iterator's #epoch_detail returns
def epoch_detail
  main_iterator = @iterators[:main]
  main_iterator.epoch_detail
end
#finalize ⇒ Object
Source lines 60–64:
# File 'lib/chainer/training/standard_updater.rb', line 60
# Calls #finalize on every registered iterator, in the table's iteration order.
#
# @return [Object] the iterator table (the value returned by its #each)
def finalize
  @iterators.each { |_name, iter| iter.finalize }
end
#get_all_optimizers ⇒ Object
Source lines 27–29:
# File 'lib/chainer/training/standard_updater.rb', line 27
# Returns every registered optimizer, keyed by name.
#
# Hash#to_h returns the receiver itself for a plain Hash. The old body
# therefore handed callers the updater's internal table, so adding or
# removing entries in the result silently changed which optimizers
# #serialize and #update_core would see. This version returns a shallow copy.
#
# @return [Hash] a new Hash of name => optimizer; the optimizer objects
#   themselves are shared, not copied
def get_all_optimizers
  @optimizers.to_h.dup
end
#get_optimizer(name) ⇒ Object
Source lines 23–25:
# File 'lib/chainer/training/standard_updater.rb', line 23
# Looks up one optimizer by the name it was registered under.
#
# @param name [Object] key in the optimizer table (:main when the constructor
#   received a single optimizer)
# @return [Object, nil] the matching optimizer. When the name is unknown, the
#   table's default is returned; that is nil for the tables the constructor builds.
def get_optimizer(name)
  table = @optimizers
  table[name]
end
#serialize(serializer) ⇒ Object
Source lines 66–76:
# File 'lib/chainer/training/standard_updater.rb', line 66
# Saves or restores the updater's state through +serializer+.
#
# Each iterator receives the child serializer "iterator:<name>". Each
# optimizer receives "optimizer:<name>", and its target receives
# "model:<name>". Finally the iteration counter is passed through the
# serializer and replaced by whatever the serializer returns.
#
# @param serializer [Object] must respond to #[] (child lookup) and #call(key, value)
# @return [Object] the new iteration counter
def serialize(serializer)
  @iterators.each do |iter_name, iter|
    iter.serialize(serializer["iterator:#{iter_name}"])
  end

  @optimizers.each do |opt_name, opt|
    opt.serialize(serializer["optimizer:#{opt_name}"])
    model = opt.target
    model.serialize(serializer["model:#{opt_name}"])
  end

  @iteration = serializer.call('iteration', @iteration)
end
#update ⇒ Object
Source lines 31–34:
# File 'lib/chainer/training/standard_updater.rb', line 31
# Performs one training step, then advances the iteration counter.
#
# The counter is only incremented after #update_core returns normally.
#
# @return [Object] the incremented iteration counter
def update
  update_core
  @iteration = @iteration + 1
end
#update_core ⇒ Object
Source lines 44–58:
# File 'lib/chainer/training/standard_updater.rb', line 44
# Runs one optimization step using the :main iterator and the :main optimizer.
#
# The next batch from the main iterator is passed through the converter,
# together with the configured device. The converted value is forwarded to
# the main optimizer's #update as follows:
# * an Array is splatted into positional arguments;
# * a Hash is splatted into keyword arguments;
# * anything else is passed as a single positional argument.
# The loss function is @loss_func when set, otherwise the optimizer's target.
#
# @return [Object] whatever the optimizer's #update returns
def update_core
  batch = @iterators[:main].next
  inputs = @converter.call(batch, device: @device)
  opt = @optimizers[:main]
  lossfun = @loss_func || opt.target

  case inputs
  when Array then opt.update(lossfun, *inputs)
  when Hash  then opt.update(lossfun, **inputs)
  else            opt.update(lossfun, inputs)
  end
end