Class: Chainer::Training::Trainer
- Inherits: Object (ancestry: Object → Chainer::Training::Trainer)
- Defined in: lib/chainer/training/trainer.rb
Instance Attribute Summary
- #observation ⇒ Object — Returns the value of attribute observation.
- #out ⇒ Object — Returns the value of attribute out.
- #stop_trigger ⇒ Object — Returns the value of attribute stop_trigger.
- #updater ⇒ Object — Returns the value of attribute updater.
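Instance Method Summary
- #elapsed_time ⇒ Object — Training time in seconds.
- #extend(extension, name: nil, trigger: nil, priority: nil) ⇒ Object — Registers an extension with the trainer.
- #get_extension(name) ⇒ Object — Returns the extension registered under the given name.
- #initialize(updater, stop_trigger: nil, out: 'result') ⇒ Trainer (constructor) — A new instance of Trainer.
- #run ⇒ Object — Runs the training loop until the stop trigger fires.
- #serialize(serializer) ⇒ Object — Saves or restores the trainer state.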
Constructor Details
#initialize(updater, stop_trigger: nil, out: 'result') ⇒ Trainer
Returns a new instance of Trainer.
Source: lib/chainer/training/trainer.rb, lines 16–40.
# File 'lib/chainer/training/trainer.rb', line 16
# Builds a trainer around +updater+.
#
# A new Reporter is created. For each (name, optimizer) pair yielded by
# +updater.get_all_optimizers+:
# - the optimizer's target is registered as an observer under that name;
# - every link yielded by +target.namedlinks(skipself: true)+ is registered
#   under the name followed by the yielded suffix.
# Finally the updater is connected back to this trainer via +connect_trainer+.
#
# @param updater [Object] the updater that performs parameter updates.
# @param stop_trigger [Object, nil] trigger specification, normalized with
#   Chainer::Training::Util.get_trigger.
# @param out [String] output directory; #run creates it with FileUtils.mkdir_p.
def initialize(updater, stop_trigger: nil, out: 'result')
  @updater = updater
  @stop_trigger = Chainer::Training::Util.get_trigger(stop_trigger)
  @observation = {}
  @out = out

  @reporter = Reporter.new.tap do |rep|
    updater.get_all_optimizers.each do |(opt_name, optimizer)|
      rep.add_observer(opt_name, optimizer.target)
      optimizer.target.namedlinks(skipself: true) do |suffix, link|
        rep.add_observer(opt_name.to_s + suffix, link)
      end
    end
  end

  # Bookkeeping for #run / #elapsed_time: no extensions, not started, not done.
  @done = false
  @extensions = {}
  @start_at = nil
  @snapshot_elapsed_time = 0.0
  @final_elapsed_time = nil

  updater.connect_trainer(self)
end
Instance Attribute Details
#observation ⇒ Object
Returns the value of attribute observation.
Source: lib/chainer/training/trainer.rb, line 14.
# File 'lib/chainer/training/trainer.rb', line 14
# The observation Hash of the current iteration.
#
# #run replaces it with a fresh Hash at the start of every iteration and
# passes it to the reporter's +scope+.
#
# @return [Hash]
def observation
  @observation
end
#out ⇒ Object
Returns the value of attribute out.
Source: lib/chainer/training/trainer.rb, line 14.
# File 'lib/chainer/training/trainer.rb', line 14
# Output directory of the trainer.
#
# Set from the +out:+ keyword of the constructor (default 'result');
# #run creates it with FileUtils.mkdir_p.
#
# @return [String]
def out
  @out
end
#stop_trigger ⇒ Object
Returns the value of attribute stop_trigger.
Source: lib/chainer/training/trainer.rb, line 14.
# File 'lib/chainer/training/trainer.rb', line 14
# Trigger that ends the training loop.
#
# Built by Chainer::Training::Util.get_trigger in the constructor; #run
# calls it with the trainer before every iteration and stops once it
# returns a truthy value.
#
# @return [Object] a callable trigger
def stop_trigger
  @stop_trigger
end
#updater ⇒ Object
Returns the value of attribute updater.
Source: lib/chainer/training/trainer.rb, line 14.
# File 'lib/chainer/training/trainer.rb', line 14
# The updater passed to the constructor.
#
# #run calls its +update+ once per iteration and its +finalize+ when the
# loop exits; #serialize delegates to its +serialize+.
#
# @return [Object]
def updater
  @updater
end
Instance Method Details
#elapsed_time ⇒ Object
Source: lib/chainer/training/trainer.rb, lines 42–47.
# File 'lib/chainer/training/trainer.rb', line 42
# Training time in seconds.
#
# Once @done is set this returns the stored @final_elapsed_time. While
# training, it is the wall-clock time since @start_at plus
# @snapshot_elapsed_time (the value #serialize reads back when loading).
#
# @return [Float]
# @raise [RuntimeError] if training has not been started yet.
def elapsed_time
  if @done
    @final_elapsed_time
  elsif @start_at.nil?
    raise "training has not been started yet"
  else
    now = Time.now.to_f
    now - @start_at + @snapshot_elapsed_time.to_f
  end
end
#extend(extension, name: nil, trigger: nil, priority: nil) ⇒ Object
Source: lib/chainer/training/trainer.rb, lines 49–81.
# File 'lib/chainer/training/trainer.rb', line 49
# Registers an extension with the trainer.
#
# If the resolved name is already taken, the suffixes _1, _2, ... are
# appended until an unused name is found. The final name is written back
# to +extension.name+.
#
# NOTE: this method shadows Object#extend for trainer instances, so modules
# cannot be mixed into a Trainer object with +extend+.
#
# @param extension [Object] the extension; must respond to +name+ and +name=+.
# @param name [String, nil] registration name; defaults to +extension.name+,
#   then +extension.default_name+.
# @param trigger [Object, nil] trigger spec passed to Util.get_trigger; defaults
#   to +extension.trigger+ when defined, else [1, 'iteration'].
# @param priority [Numeric, nil] defaults to +extension.priority+ when defined,
#   else Extension::PRIORITY_READER.
# @return [Object] the stored ExtensionEntry
# @raise [ArgumentError] if no name can be determined.
# @raise [RuntimeError] if the resolved name is "training".
def extend(extension, name: nil, trigger: nil, priority: nil)
  if name.nil?
    name = extension.name || extension.default_name
    raise ArgumentError, 'name is not given for the extension' unless name
  end
  raise 'the name "training" is prohibited as an extension name' if name == 'training'

  if trigger.nil?
    trigger = extension.respond_to?(:trigger) ? extension.trigger : [1, 'iteration']
  end
  trigger = Chainer::Training::Util.get_trigger(trigger)

  if priority.nil?
    priority = extension.respond_to?(:priority) ? extension.priority : Extension::PRIORITY_READER
  end

  # Find a free name. The previous version iterated with a block parameter
  # that shadowed +modified_name+, so the suffix was never applied and a
  # same-named extension silently replaced the earlier one.
  modified_name = name
  ordinal = 0
  while @extensions.key?(modified_name)
    ordinal += 1
    modified_name = "#{name}_#{ordinal}"
  end

  extension.name = modified_name
  @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger)
end
#get_extension(name) ⇒ Object
Source: lib/chainer/training/trainer.rb, lines 83–89.
# File 'lib/chainer/training/trainer.rb', line 83
# Returns the extension registered under +name+.
#
# @param name [String] the (possibly suffixed) name the extension was registered under.
# @return [Object] the extension object stored in the entry.
# @raise [RuntimeError] if no extension is registered under +name+.
def get_extension(name)
  # Hash#key? is a direct hash lookup. The previous keys.include? built an
  # array of all keys and scanned it linearly on every call.
  raise "extension #{name} not found" unless @extensions.key?(name)

  @extensions[name].extension
end
#run ⇒ Object
Source: lib/chainer/training/trainer.rb, lines 91–128.
# File 'lib/chainer/training/trainer.rb', line 91
# Runs the training loop until the stop trigger fires.
#
# Steps:
# 1. Create the output directory.
# 2. Call each extension's optional +init+ hook.
# 3. Repeatedly call the updater's +update+ inside a fresh reporter scope,
#    then invoke every extension whose trigger fires.
#
# Extensions are visited in descending priority; extensions with equal
# priority keep their registration order. On exit, whether normal or via an
# exception, each extension's optional +finalize+ hook and the updater's
# +finalize+ are called.
#
# @return [true]
# @raise [RuntimeError] if the training loop has already been run.
def run
  raise 'cannot run training loop multiple times' if @done
  FileUtils.mkdir_p(@out)

  # Higher priority first. sort_by is not stable, so the registration index
  # breaks ties to keep equal-priority extensions in insertion order.
  extensions = @extensions.to_a.sort_by.with_index { |(_, entry), index| [-entry.priority, index] }

  @start_at = Time.now.to_f

  extensions.each do |(_, entry)|
    entry.extension.init(self) if entry.extension.respond_to?(:init)
  end

  update = @updater.method(:update)
  reporter = @reporter
  stop_trigger = @stop_trigger

  begin
    until stop_trigger.(self)
      @observation = {}
      reporter.scope(@observation) do
        update.call
        extensions.each do |(_, entry)|
          entry.extension.(self) if entry.trigger.(self)
        end
      end
    end
  ensure
    extensions.each do |(_, entry)|
      entry.extension.finalize if entry.extension.respond_to?(:finalize)
    end
    @updater.finalize
  end

  # Call the elapsed_time method (the previous @elapsed_time ivar was never
  # assigned, so nil was recorded). It must run before @done is set:
  # afterwards elapsed_time returns @final_elapsed_time itself.
  @final_elapsed_time = elapsed_time
  @done = true
end
#serialize(serializer) ⇒ Object
Source: lib/chainer/training/trainer.rb, lines 130–151.
# File 'lib/chainer/training/trainer.rb', line 130
# Saves or restores the trainer state through +serializer+.
#
# Delegates to the updater under 'updater'. Also delegates to the stop
# trigger and to each extension and extension trigger that respond to
# +serialize+, using child scopes 'stop_trigger', 'extensions/<name>' and
# 'extension_triggers/<name>'. When +serializer+ is a Chainer::Serializer,
# the current elapsed time is written under '_snapshot_elapsed_time'.
# Otherwise +serializer+ is treated as a deserializer and that value
# (default 0.0) is read back into @snapshot_elapsed_time.
#
# @param serializer [Object] serializer or deserializer; must support +[]+ for
#   child scopes and be callable as +serializer.(key, value)+.
def serialize(serializer)
  updater.serialize(serializer['updater'])
  @stop_trigger.serialize(serializer['stop_trigger']) if @stop_trigger.respond_to?(:serialize)

  extension_scope = serializer['extensions']
  trigger_scope = serializer['extension_triggers']
  @extensions.each do |entry_name, entry|
    ext = entry.extension
    ext.serialize(extension_scope[entry_name]) if ext.respond_to?(:serialize)
    trg = entry.trigger
    trg.serialize(trigger_scope[entry_name]) if trg.respond_to?(:serialize)
  end

  return serializer.('_snapshot_elapsed_time', elapsed_time) if serializer.is_a?(Chainer::Serializer)

  @snapshot_elapsed_time = serializer.('_snapshot_elapsed_time', 0.0)
end