Class: Chainer::Functions::Connection::Convolution2DGradW
- Inherits: Chainer::FunctionNode
- Ancestry: Object → Chainer::FunctionNode → Chainer::Functions::Connection::Convolution2DGradW
- Defined in:
- lib/chainer/functions/connection/convolution_2d_grad_w.rb
Instance Attribute Summary
Attributes inherited from Chainer::FunctionNode
Instance Method Summary
- #backward(indexes, grad_outputs) ⇒ Object
- #forward(inputs) ⇒ Object
- #initialize(conv2d) ⇒ Convolution2DGradW — constructor; returns a new instance of Convolution2DGradW.
Methods inherited from Chainer::FunctionNode
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Constructor Details
#initialize(conv2d) ⇒ Convolution2DGradW
Returns a new instance of Convolution2DGradW.
5 6 7 8 9 10 11 12 13 14 15 |
# File 'lib/chainer/functions/connection/convolution_2d_grad_w.rb', line 5
# Captures, from a forward convolution node, everything needed to compute
# the weight gradient: kernel height/width, strides, paddings, the
# cover_all flag and the dtype of the weight.
#
# @param conv2d [Object] the forward convolution function node; its second
#   input (inputs[1]) is the weight node
def initialize(conv2d)
  weight_node = conv2d.inputs[1]
  # The weight shape's trailing two entries (from index 2) are the kernel size.
  @kh, @kw = weight_node.shape.drop(2)
  @sy, @sx = conv2d.sy, conv2d.sx
  @ph, @pw = conv2d.ph, conv2d.pw
  @cover_all = conv2d.cover_all
  # Remembered so #forward can cast the gradient back to the weight's dtype.
  @w_dtype = weight_node.dtype
end
Instance Method Details
#backward(indexes, grad_outputs) ⇒ Object
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# File 'lib/chainer/functions/connection/convolution_2d_grad_w.rb', line 26
# Back-propagates the gradient of this node's output (gW) to the requested
# inputs.
#
# @param indexes [Array<Integer>] which input gradients to compute:
#   0 -> gradient w.r.t. x (via deconvolution), 1 -> gradient w.r.t. gy
#   (via convolution)
# @param grad_outputs [Array] upstream gradients; only the first (ggw) is used
# @return [Array] the requested gradients, in the order [gx, ggy] filtered by
#   +indexes+
def backward(indexes, grad_outputs)
  x, gy = get_retained_inputs
  ggw = grad_outputs.first

  ret = []
  if indexes.include?(0)
    # outsize pins the deconvolution output to x's spatial size, since
    # several input sizes can map to the same convolution output size.
    xh, xw = x.shape[2..-1]
    # Fully qualified (like the convolution call below) so resolution does
    # not depend on the lexical nesting of this class definition.
    gx = Chainer::Functions::Connection::Deconvolution2DFunction.deconvolution_2d(
      gy, ggw, stride: [@sy, @sx], pad: [@ph, @pw], outsize: [xh, xw]
    )
    ret << gx
  end
  if indexes.include?(1)
    ggy = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(
      x, ggw, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all
    )
    ret << ggy
  end
  ret
end
#forward(inputs) ⇒ Object
17 18 19 20 21 22 23 24 |
# File 'lib/chainer/functions/connection/convolution_2d_grad_w.rb', line 17
# Computes the weight gradient from the layer input x and the upstream
# gradient gy: x is unfolded into patches with im2col (using the stored
# kernel size, stride, padding and cover_all), then contracted with gy.
#
# @param inputs [Array] two-element array [x, gy]
# @return [Array] single-element array [gw], cast to the original weight dtype
def forward(inputs)
  # Both inputs are read again in #backward via get_retained_inputs.
  retain_inputs([0, 1])
  input_x, grad_y = inputs
  patches = Chainer::Utils::Conv.im2col(
    input_x, @kh, @kw, @sy, @sx, @ph, @pw,
    cover_all: @cover_all
  )
  # NOTE(review): presumably grad_y is (N, C_out, out_h, out_w) and patches is
  # (N, C_in, kh, kw, out_h, out_w), so contracting the batch and output
  # spatial axes gives a (C_out, C_in, kh, kw) gradient — confirm.
  grad_w = Chainer::Utils::Math.tensordot(grad_y, patches, [[0, 2, 3], [0, 4, 5]])
  [grad_w.cast_to(@w_dtype)]
end