Class: Chainer::Functions::Connection::Convolution2DGradW

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/connection/convolution_2d_grad_w.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Instance Method Summary

#backward(indexes, grad_outputs) ⇒ Object
#forward(inputs) ⇒ Object

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(conv2d) ⇒ Convolution2DGradW

Returns a new instance of Convolution2DGradW.



(source lines 5–15)
# File 'lib/chainer/functions/connection/convolution_2d_grad_w.rb', line 5

# Records the geometry of a forward 2-D convolution so that its weight
# gradient can be computed later. Only plain values are copied out of
# +conv2d+; the node itself is not stored.
#
# @param conv2d [Object] presumably the Convolution2DFunction node being
#   differentiated (TODO confirm); must respond to #inputs, #sy, #sx, #ph,
#   #pw and #cover_all. Its second input (+inputs[1]+) is treated as the
#   weight node.
def initialize(conv2d)
  weight_node = conv2d.inputs[1]

  # The last two axes of the weight shape are read as kernel height/width
  # (assumes a (out_c, in_c, kh, kw) layout — TODO confirm).
  @kh, @kw = weight_node.shape[2..-1]

  # Stride and padding along y/x, plus the cover_all flag, exactly as
  # configured on the forward convolution.
  @sy, @sx = conv2d.sy, conv2d.sx
  @ph, @pw = conv2d.ph, conv2d.pw
  @cover_all = conv2d.cover_all

  # Target dtype for the computed weight gradient (see #forward).
  @w_dtype = weight_node.dtype
end

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object



(source lines 26–43)
# File 'lib/chainer/functions/connection/convolution_2d_grad_w.rb', line 26

# Differentiates the weight-gradient computation with respect to its
# inputs (x, gy), enabling higher-order gradients of a 2-D convolution.
#
# @param indexes [Array<Integer>] which input gradients to produce:
#   0 requests the gradient for x, 1 requests the gradient for gy.
# @param grad_outputs [Array] upstream gradients; only the first element
#   (the gradient w.r.t. gW, called ggW here) is used.
# @return [Array] the requested gradients, gx before ggy, omitting any
#   index not present in +indexes+.
def backward(indexes, grad_outputs)
  x, gy = get_retained_inputs
  ggw = grad_outputs.first

  [].tap do |grads|
    if indexes.include?(0)
      # gx: deconvolve gy with ggW, pinning the output spatial size to
      # x's last two axes so it matches x exactly.
      in_h, in_w = x.shape[2..-1]
      grads << Chainer::Functions::Connection::Deconvolution2DFunction.deconvolution_2d(
        gy, ggw,
        stride: [@sy, @sx], pad: [@ph, @pw], outsize: [in_h, in_w]
      )
    end

    if indexes.include?(1)
      # ggy: convolve x with ggW using the stride/pad/cover_all captured
      # from the forward convolution.
      grads << Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(
        x, ggw,
        stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all
      )
    end
  end
end

#forward(inputs) ⇒ Object



(source lines 17–24)
# File 'lib/chainer/functions/connection/convolution_2d_grad_w.rb', line 17

# Computes the gradient of a 2-D convolution with respect to its weight.
#
# x is unrolled with im2col using the kernel size, stride, padding and
# cover_all captured at construction. gy is then contracted against the
# columns over gy axes (0, 2, 3) and col axes (0, 4, 5). Assuming
# x is (N, C, H, W), gy is (N, out_c, out_h, out_w) and col is
# (N, C, kh, kw, out_h, out_w) — TODO confirm — the result is laid out
# like the weight, (out_c, C, kh, kw).
#
# @param inputs [Array] two-element array [x, gy].
# @return [Array] a single-element array holding gW, cast to the weight's
#   original dtype.
def forward(inputs)
  x, gy = inputs
  # Both x and gy are read again by #backward via get_retained_inputs.
  retain_inputs([0, 1])

  col = Chainer::Utils::Conv.im2col(x, @kh, @kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
  summed = Chainer::Utils::Math.tensordot(gy, col, [[0, 2, 3], [0, 4, 5]])

  [summed.cast_to(@w_dtype)]
end