Class: Chainer::Functions::Pooling::MaxPooling2DGrad

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/pooling/max_pooling_2d.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(mpool2d) ⇒ MaxPooling2DGrad

Returns a new instance of MaxPooling2DGrad.



Source lines 42–54:
# File 'lib/chainer/functions/pooling/max_pooling_2d.rb', line 42

# Builds the gradient node for a MaxPooling2D forward node.
#
# Snapshots the pooling hyperparameters (kernel size, stride, padding,
# cover_all) together with the state the forward pass recorded (indexes,
# input shape and dtype), and keeps a reference to the originating node
# so that #backward can construct a MaxPooling2DWithIndexes from it.
#
# @param mpool2d [Object] the forward pooling node; must respond to
#   kh, kw, sy, sx, ph, pw, cover_all, indexes, in_shape and in_dtype.
def initialize(mpool2d)
  %i[kh kw sy sx ph pw cover_all indexes in_shape in_dtype].each do |attr|
    instance_variable_set(:"@#{attr}", mpool2d.public_send(attr))
  end
  @mpool2d = mpool2d
end

Instance Method Details

#backward(indexes, ggx) ⇒ Object



Source lines 75–77:
# File 'lib/chainer/functions/pooling/max_pooling_2d.rb', line 75

# Backward pass of the gradient node (used for double backpropagation).
#
# Delegates to a MaxPooling2DWithIndexes node built from the same forward
# MaxPooling2D node and applies it to the incoming gradients. Presumably
# that node reuses the forward pass's recorded argmax indexes; confirm
# against MaxPooling2DWithIndexes.
#
# @param _indexes [Object] target input indexes; not used here.
# @param ggx [Array] gradients with respect to this node's outputs.
# @return whatever MaxPooling2DWithIndexes#apply returns for ggx.
def backward(_indexes, ggx)
  pool_with_indexes = MaxPooling2DWithIndexes.new(@mpool2d)
  pool_with_indexes.apply(ggx)
end

#forward(gy) ⇒ Object



Source lines 56–73:
# File 'lib/chainer/functions/pooling/max_pooling_2d.rb', line 56

# Computes the gradient of 2D max pooling with respect to its input.
#
# Scatters every upstream gradient value into a zero-filled column buffer
# at the window slot recorded by the forward pass in @indexes. It then
# reorders the buffer's axes and folds it back into image layout with
# Chainer::Utils::Conv.col2im.
#
# @param gy [Array] one-element array; gy[0] is the upstream gradient and
#   is unpacked as (n, c, out_h, out_w).
# @return [Array] one-element array holding the input gradient gx, sized
#   by the spatial dims (h, w) taken from @in_shape[2..-1].
def forward(gy)
  n, c, out_h, out_w = gy[0].shape
  h, w = @in_shape[2..-1]
  kh, kw = @kh, @kw
  window = kh * kw

  gcol = @in_dtype.zeros(n * c * out_h * out_w * window)

  # Each flattened @indexes entry is treated as an offset inside its own
  # kh*kw window. Adding i * window to the i-th entry turns it into an
  # absolute position in the flat column buffer. `+=` rebinds the local,
  # so @indexes itself is not modified.
  indexes = @indexes.flatten
  indexes += indexes.class.new(indexes.size).seq(0, window)

  # Indexed assignment copies the values into gcol, so no defensive dup is needed.
  gcol[indexes] = gy[0].flatten

  # (n, c, out_h, out_w, kh, kw) -> (n, c, kh, kw, out_h, out_w); presumably
  # the column layout Conv.col2im expects (confirm against Utils::Conv).
  gcol = gcol.reshape(n, c, out_h, out_w, kh, kw).swapaxes(2, 4).swapaxes(3, 5)

  gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
  [gx]
end