Class: Chainer::Iterators::SerialIterator

Inherits:
Dataset::Iterator
Defined in:
lib/chainer/iterators/serial_iterator.rb

Instance Attribute Summary

Instance Method Summary

Methods inherited from Dataset::Iterator

#finalize

Constructor Details

#initialize(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default) ⇒ SerialIterator

Returns a new instance of SerialIterator.



6
7
8
9
10
11
12
13
14
15
# File 'lib/chainer/iterators/serial_iterator.rb', line 6

# Builds an iterator that yields minibatches of +batch_size+ examples
# taken from +dataset+.
#
# @param dataset [#size, #[]] the examples to iterate over
# @param batch_size [Integer] number of examples per minibatch
# @param repeat [Boolean] when false, #next raises StopIteration once an
#   epoch has been completed
# @param shuffle [Boolean] when true, #reset visits the examples in a random order
# @param device [Object] device whose +xm+ array module is used by #reset
#   to hold the index order
def initialize(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default)
  @dataset, @batch_size = dataset, batch_size
  @repeat, @shuffle = repeat, shuffle
  @device, @xm = device, device.xm

  # Build the initial index order and zero the epoch bookkeeping.
  reset
end

Instance Attribute Details

#epoch ⇒ Object (readonly)

Returns the value of attribute epoch.



4
5
6
# File 'lib/chainer/iterators/serial_iterator.rb', line 4

# @return [Integer] number of epochs completed so far (reset to 0 by #reset,
#   incremented by #next each time it passes the end of the dataset)
def epoch; @epoch; end

#is_new_epoch ⇒ Object (readonly)

Returns the value of attribute is_new_epoch.



4
5
6
# File 'lib/chainer/iterators/serial_iterator.rb', line 4

# @return [Boolean] true when the batch most recently returned by #next
#   crossed an epoch boundary; false after #reset or a mid-epoch batch
def is_new_epoch; @is_new_epoch; end

Instance Method Details

#epoch_detail ⇒ Object



56
57
58
# File 'lib/chainer/iterators/serial_iterator.rb', line 56

# Training progress measured in (fractional) epochs.
#
# The whole part is the number of completed epochs; the fractional part is
# the share of the current pass already consumed, i.e. the current position
# divided by the dataset size.
#
# @return [Float]
def epoch_detail
  consumed = @current_position.to_f / @dataset.size
  @epoch + consumed
end

#next ⇒ Object

Raises:

  • (StopIteration)


17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# File 'lib/chainer/iterators/serial_iterator.rb', line 17

# Returns the next minibatch as an Array of dataset examples.
#
# Examples are taken in the order held by +@order+. When a batch runs past
# the end of the dataset and +repeat+ is enabled, it is topped up with
# examples from the start of the next epoch. This keeps every batch at
# +batch_size+ examples (as long as +batch_size+ does not exceed the dataset
# size). The order is reshuffled for the new epoch only when +shuffle+ is
# enabled; otherwise the sequential order built by #reset is kept.
#
# @return [Array] the examples in the minibatch
# @raise [StopIteration] when +repeat+ is false and one epoch has been consumed
def next
  raise StopIteration if !@repeat && @epoch > 0

  @previous_epoch_detail = epoch_detail

  i = @current_position
  n = @dataset.size
  # Unclamped end: anything past +n+ (the "rest") comes from the next epoch
  # when repeating.
  i_end = i + @batch_size

  # Examples at positions [from, to) of the current order. +to+ is clamped so
  # the index range never runs past the dataset (array slicing may reject it).
  take = lambda do |from, to|
    to = [to, n].min
    if @order.nil?
      @dataset[from...to].to_a
    else
      @order[from...to].to_a.map { |index| @dataset[index] }
    end
  end

  batch = take.call(i, i_end)

  if i_end >= n
    if @repeat
      rest = i_end - n
      # Only reshuffle when shuffling was requested.
      @order = @order.class[*@order.to_a.shuffle] if @shuffle && @order
      # Fill the remainder from the beginning of the (new) order.
      # NOTE(review): if batch_size > dataset size, rest can exceed n; the
      # top-up is clamped to one pass and current_position ends up past n.
      batch.concat(take.call(0, rest)) if rest > 0
      @current_position = rest
    else
      @current_position = 0
    end

    @epoch += 1
    @is_new_epoch = true
  else
    @is_new_epoch = false
    @current_position = i_end
  end

  batch
end

#reset ⇒ Object



85
86
87
88
89
90
91
92
93
94
95
96
97
98
# File 'lib/chainer/iterators/serial_iterator.rb', line 85

# Rewinds the iterator to the start of epoch 0.
#
# Rebuilds the index order as an +@xm::Int64+ array. The order is a random
# permutation of the dataset indices when +shuffle+ is enabled, and
# 0...size in sequence otherwise. Also clears the position and epoch
# bookkeeping.
def reset
  indices = (0...@dataset.size).to_a
  indices.shuffle! if @shuffle
  @order = @xm::Int64[*indices]

  @current_position = 0
  @epoch = 0
  @is_new_epoch = false
  # Sentinel; #serialize uses the same -1.0 when no progress has been made.
  @previous_epoch_detail = -1.0
end

#serialize(serializer) ⇒ Object



60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# File 'lib/chainer/iterators/serial_iterator.rb', line 60

# Saves or restores the iterator state through +serializer+.
#
# +serializer+ is invoked as +serializer.(key, value)+. For the scalar fields
# its return value is stored back, so the same method works for both saving
# and loading. For +order+ the array is passed and the return value is
# ignored. NOTE(review): this assumes the deserializer writes into the array
# in place; confirm.
#
# A KeyError from the serializer is treated as "key not present":
# * 'order' missing -> retried under '_order';
# * 'previous_epoch_detail' missing (snapshots from an older version) ->
#   the value is estimated from the epoch, position and batch size.
#
# @param serializer [#call] callable taking (key, value)
def serialize(serializer)
  @current_position = serializer.('current_position', @current_position)
  @epoch = serializer.('epoch', @epoch)
  @is_new_epoch = serializer.('is_new_epoch', @is_new_epoch)
  unless @order.nil?
    begin
      serializer.('order', @order)
    rescue KeyError
      # Must go through #call like the other keys (a bare +serializer(...)+
      # would look up a method named +serializer+ and raise NoMethodError).
      serializer.('_order', @order)
    end
  end

  begin
    @previous_epoch_detail = serializer.('previous_epoch_detail', @previous_epoch_detail)
  rescue KeyError
    # Guess previous_epoch_detail for snapshots from an older version.
    # Float division is required: Integer division would floor e.g.
    # (2 - 3) / 10 to -1 instead of -0.1.
    @previous_epoch_detail = @epoch + (@current_position - @batch_size).to_f / @dataset.size
    @previous_epoch_detail = if epoch_detail > 0
                               [@previous_epoch_detail, 0.0].max
                             else
                               -1.0
                             end
  end
end