Class: Chainer::Training::Extensions::Snapshot

Inherits:
Chainer::Training::Extension show all
Defined in:
lib/chainer/training/extensions/snapshot.rb

Constant Summary

Constants inherited from Chainer::Training::Extension

Chainer::Training::Extension::PRIORITY_EDITOR, Chainer::Training::Extension::PRIORITY_READER, Chainer::Training::Extension::PRIORITY_WRITER

Instance Attribute Summary collapse

Attributes inherited from Chainer::Training::Extension

#name, #priority, #trigger

Class Method Summary collapse

Instance Method Summary collapse

Methods inherited from Chainer::Training::Extension

#default_name

Constructor Details

#initialize(save_class: nil, filename_proc: nil, target: nil) ⇒ Snapshot

Returns a new instance of Snapshot.



15
16
17
18
19
20
21
# File 'lib/chainer/training/extensions/snapshot.rb', line 15

# Sets up the extension's scheduling fields and snapshot settings.
#
# @param save_class [Object, nil] serializer used by #call (it must respond to
#   +save_file+); falls back to Chainer::Serializers::MarshalSerializer when
#   nil or false.
# @param filename_proc [#call, nil] callable that receives the trainer and
#   returns the snapshot's file name; when omitted the name is
#   "snapshot_iter_<trainer.updater.iteration>".
# @param target [Object, nil] stored as-is in @target.
def initialize(save_class: nil, filename_proc: nil, target: nil)
  default_namer = proc { |trainer| "snapshot_iter_#{trainer.updater.iteration}" }

  @target = target
  @filename_proc = filename_proc || default_namer
  @save_class = save_class || Chainer::Serializers::MarshalSerializer
  # NOTE(review): presumably "run once per epoch" — confirm against the trigger handling.
  @trigger = [1, 'epoch']
  # NOTE(review): meaning of -100 relative to the PRIORITY_* constants is not visible here — confirm.
  @priority = -100
end

Instance Attribute Details

#filename_proc ⇒ Object

Returns the value of attribute filename_proc.



5
6
7
# File 'lib/chainer/training/extensions/snapshot.rb', line 5

# Reader for @filename_proc, the callable #call invokes with the trainer to
# obtain the snapshot's file name.
#
# @return [Object] the current value of @filename_proc
def filename_proc; @filename_proc; end

#save_class ⇒ Object

Returns the value of attribute save_class.



5
6
7
# File 'lib/chainer/training/extensions/snapshot.rb', line 5

# Reader for @save_class, the serializer whose +save_file+ is called by #call.
#
# @return [Object] the current value of @save_class
def save_class; @save_class; end

#target ⇒ Object

Returns the value of attribute target.



5
6
7
# File 'lib/chainer/training/extensions/snapshot.rb', line 5

# Reader for @target, the value given as +target:+ to the constructor.
#
# @return [Object, nil] the current value of @target
def target; @target; end

Class Method Details

.snapshot(save_class: nil, &block) ⇒ Object



11
12
13
# File 'lib/chainer/training/extensions/snapshot.rb', line 11

# Builds a Snapshot with no explicit target.
#
# @param save_class [Object, nil] forwarded to the constructor as +save_class:+
# @yield optional block, forwarded to the constructor as +filename_proc:+
#   (nil when no block is given)
# @return [Snapshot] the newly constructed instance
def self.snapshot(save_class: nil, &block)
  options = { save_class: save_class, filename_proc: block }
  new(**options)
end

.snapshot_object(target:, save_class:, &block) ⇒ Object



7
8
9
# File 'lib/chainer/training/extensions/snapshot.rb', line 7

# Builds a Snapshot bound to a specific target object.
#
# @param target [Object] forwarded to the constructor as +target:+
# @param save_class [Object] forwarded to the constructor as +save_class:+
# @yield optional block, forwarded to the constructor as +filename_proc:+
#   (nil when no block is given)
# @return [Snapshot] the newly constructed instance
def self.snapshot_object(target:, save_class:, &block)
  options = { save_class: save_class, filename_proc: block, target: target }
  new(**options)
end

Instance Method Details

#call(trainer) ⇒ Object



23
24
25
26
27
28
29
30
# File 'lib/chainer/training/extensions/snapshot.rb', line 23

# Serializes the snapshot target into the trainer's output directory.
#
# The data is written to a temporary file created inside +trainer.out+ and
# only moved to its final name after +save_file+ has returned, so the final
# file name never refers to a half-written snapshot.
#
# @param trainer [Object] must respond to +out+ (output directory path) and is
#   passed to +filename_proc+ to obtain the snapshot file name. It is also the
#   object that gets saved when no explicit target was configured.
# @return [Object] whatever FileUtils.move returns; not meaningful to callers
def call(trainer)
  # Save the configured target if there is one, otherwise the trainer itself.
  target = @target || trainer
  filename = filename_proc.call(trainer)
  prefix = "tmp#{filename}"

  # Tempfile.create takes basename and tmpdir positionally; passing them as
  # keywords would not place the file in trainer.out.
  temp_file = Tempfile.create(prefix, trainer.out)
  temp_path = temp_file.path
  # Only the path is needed: save_file is given the path, not the handle, so
  # release the descriptor right away instead of leaking it.
  temp_file.close

  begin
    save_class.save_file(temp_path, target)
    FileUtils.move(temp_path, File.join(trainer.out, filename))
  ensure
    # Block-less Tempfile.create does not auto-delete; remove the leftover if
    # saving or moving failed. After a successful move this is a no-op.
    FileUtils.rm_f(temp_path)
  end
end