Class: Chainer::Training::Extensions::PrintReport

Inherits:
Chainer::Training::Extension
Defined in:
lib/chainer/training/extensions/print_report.rb

Constant Summary

Constants inherited from Chainer::Training::Extension

Chainer::Training::Extension::PRIORITY_EDITOR, Chainer::Training::Extension::PRIORITY_READER, Chainer::Training::Extension::PRIORITY_WRITER

Instance Attribute Summary

Attributes inherited from Chainer::Training::Extension

#name, #priority, #trigger

Instance Method Summary

Methods inherited from Chainer::Training::Extension

#default_name

Constructor Details

#initialize(entries, log_report: 'LogReport', out: STDOUT) ⇒ PrintReport

Returns a new instance of PrintReport.



(source lines 5–23)
# File 'lib/chainer/training/extensions/print_report.rb', line 5

# Builds a PrintReport that prints the given +entries+ as aligned columns.
#
# @param entries [Array<String>] observation keys to print, in column order.
# @param log_report [String, LogReport] either the name of the LogReport
#   extension to look up on the trainer, or a LogReport instance.
# @param out [#write] destination for the header and rows (defaults to STDOUT).
def initialize(entries, log_report: 'LogReport', out: STDOUT)
  @entries = entries
  @log_report = log_report
  @out = out

  # How many log records have already been written to +out+.
  @log_len = 0

  # Each column is as wide as its entry name, but never narrower than 10.
  columns = entries.map { |name| [name, [name.size, 10].max] }

  # One triple per column: [entry name, printf template for a value,
  # blank of the same printed width (presumably used when a value is missing)].
  @templates = columns.map do |name, width|
    [name, "%-#{width}g  ", ' ' * (width + 2)]
  end

  # Left-justified column titles separated by two spaces, printed once.
  @header = columns.map { |name, width| sprintf("%-#{width}s", name) }.join('  ') + "\n"
end

Instance Method Details

#call(trainer) ⇒ Object



(source lines 25–45)
# File 'lib/chainer/training/extensions/print_report.rb', line 25

# Writes the header on the first invocation, then writes every log record
# that has been appended to the LogReport since the previous invocation.
#
# @param trainer [Object] the trainer; used to look up the LogReport
#   extension by name when +@log_report+ is a String, and passed to the
#   LogReport when +@log_report+ is an instance.
# @raise [TypeError] if +@log_report+ is neither a String nor a LogReport.
def call(trainer)
  if @header
    @out.write(@header)
    @header = nil # the header is only written once
  end

  # Resolve the LogReport to read from. The instance branch must assign the
  # local explicitly; previously it stayed nil and the call crashed.
  log_report =
    case @log_report
    when String
      trainer.get_extension(@log_report)
    when LogReport
      # Given an instance rather than a name, invoke it directly so its log
      # is brought up to date before reading it.
      @log_report.(trainer)
      @log_report
    else
      raise TypeError, "log report has a wrong type #{@log_report.class}"
    end

  log = log_report.log
  while log.size > @log_len
    # ANSI "erase in display" (cursor to end of screen) before each row.
    @out.write("\033[J")
    # NOTE(review): `print` is presumably a private formatting helper defined
    # elsewhere in this class (not visible here); if it is not, Kernel#print
    # would write to STDOUT instead of @out — confirm.
    print(log[@log_len])
    @log_len += 1
  end
end

#serialize(serializer) ⇒ Object



(source lines 47–51)
# File 'lib/chainer/training/extensions/print_report.rb', line 47

# Serializes the wrapped LogReport, but only when this extension was given a
# LogReport instance. When +@log_report+ is a name, nothing is written here
# and nil is returned.
#
# @param serializer [#[]] serializer whose '_log_report' child receives the
#   LogReport's state.
def serialize(serializer)
  report = @log_report
  return unless report.is_a?(Chainer::Training::Extensions::LogReport)

  report.serialize(serializer['_log_report'])
end