Class: Chainer::Functions::Pooling::MaxPooling2DWithIndexes

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/pooling/max_pooling_2d.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(mpool2d) ⇒ MaxPooling2DWithIndexes

Returns a new instance of MaxPooling2DWithIndexes.



(Source listing: lines 81–90.)
# File 'lib/chainer/functions/pooling/max_pooling_2d.rb', line 81

# Builds the index-driven variant from an existing pooling node by copying its
# kernel size, strides, padding, cover_all flag and recorded indexes into
# same-named instance variables.
#
# @param mpool2d [Object] the source node (presumably a MaxPooling2D that has
#   already run forward — confirm with callers); must respond to #kh, #kw,
#   #sy, #sx, #ph, #pw, #cover_all and #indexes, each read once in that order.
def initialize(mpool2d)
  %i[kh kw sy sx ph pw cover_all indexes].each do |attr|
    instance_variable_set(:"@#{attr}", mpool2d.public_send(attr))
  end
end

Instance Method Details

#forward(x) ⇒ Object



(Source listing: lines 92–108.)
# File 'lib/chainer/functions/pooling/max_pooling_2d.rb', line 92

# Re-applies max pooling to x[0] using the window positions stored in
# @indexes (copied from the pooling node in the constructor) instead of
# recomputing the maxima.
#
# @param x [Array] single-element array; x[0] is handed to
#   Chainer::Utils::Conv.im2col together with the stored kernel/stride/padding
#   settings (presumably a Numo::NArray of shape (n, c, h, w) — TODO confirm).
# @return [Array] single-element array with the gathered values, reshaped to
#   (n, c, out_h, out_w) and of the same array class as the im2col result.
def forward(x)
  col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
  n, c, kh, kw, out_h, out_w = col.shape

  # Move the window axis last and flatten, so each row holds one pooling
  # window: shape (n * c * out_h * out_w, kh * kw).
  col = col.reshape(n, c, kh * kw, out_h, out_w)
  col = col.transpose(0, 1, 3, 4, 2).reshape(nil, kh * kw)

  # One entry per row above, in (n, c, out_h, out_w) order. Assumes @indexes
  # has that shape and holds per-window positions in 0...kh*kw — verify
  # against the producer in MaxPooling2D. Only read, so no copy is needed.
  indexes = @indexes.flatten

  # Equivalent of numpy's col[arange(len(indexes)), indexes].
  # Row numbers are plain Integers rather than elements of x's dtype: float
  # dtypes cannot represent large integers exactly, which would select the
  # wrong rows for big inputs.
  # TODO: replace with a vectorized gather once the index dtype is pinned down.
  gathered = col.class.zeros(indexes.size)
  indexes.size.times do |row|
    gathered[row] = col[row, indexes[row]]
  end

  [gathered.reshape(n, c, out_h, out_w)]
end