56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
# File 'lib/chainer/functions/pooling/max_pooling_2d.rb', line 56
# Scatters the upstream gradient of a 2-D max pooling back onto the input grid.
#
# Each output element owns one contiguous slot of kh * kw entries in a flat,
# zero-filled column buffer. Its gradient value is written to the single
# position inside that slot given by @indexes; every other position stays
# zero. The buffer is then reshaped and rearranged to
# (n, c, kh, kw, out_h, out_w) and handed to Chainer::Utils::Conv.col2im
# together with the stride (@sy, @sx), padding (@ph, @pw) and input size
# (h, w).
#
# NOTE(review): assumes every entry of @indexes lies in [0, kh * kw), i.e. it
# is the flat offset of the selected element within its pooling window —
# confirm against the code that fills @indexes.
#
# @param gy [Array] upstream gradients; gy[0] is destructured as
#   (n, c, out_h, out_w)
# @return [Array] one-element array holding the result of col2im
def forward(gy)
  n, c, out_h, out_w = gy[0].shape
  h, w = @in_shape[2..-1]
  kh, kw = @kh, @kw
  window = kh * kw

  gcol = @in_dtype.zeros(n * c * out_h * out_w * window)

  # Turn per-window offsets into offsets into the flat buffer: output element
  # i starts at i * window, so add 0, window, 2 * window, ...
  # (Numo `seq(beg, step)` fills a[i] = beg + i * step.)
  indexes = @indexes.flatten
  indexes += indexes.class.new(indexes.size).seq(0, window)

  # Indexed assignment copies the values into gcol, so no defensive dup.
  gcol[indexes] = gy[0].flatten

  # (n, c, out_h, out_w, kh, kw) -> (n, c, kh, kw, out_h, out_w)
  gcol = gcol.reshape(n, c, out_h, out_w, kh, kw).swapaxes(2, 4).swapaxes(3, 5)

  gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
  [gx]
end
|