Class: Chainer::Functions::Connection::LinearFunction
Instance Attribute Summary
#inputs, #outputs, #rank
Class Method Summary
collapse
Instance Method Summary
collapse
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #initialize, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Class Method Details
.linear(x, w, b = nil) ⇒ Object
5
6
7
8
9
10
11
12
13
14
15
16
17
|
# File 'lib/chainer/functions/connection/linear.rb', line 5
# Convenience entry point: builds a fresh function instance, applies it to
# the given operands and returns the first element of the result.
#
# When +x+ reports more than two dimensions it is first reshaped to
# (x.shape.first, -1). The bias +b+ is only passed along when it is not nil.
#
# @param x input; must respond to +ndim+, and to +reshape+/+shape+ when ndim > 2
# @param w weight operand, forwarded unchanged
# @param b optional bias operand (default: nil, meaning "no bias")
# @return the first element of whatever +apply+ returns
def self.linear(x, w, b = nil)
  x = x.reshape(x.shape.first, -1) if x.ndim > 2
  operands = b.nil? ? [x, w] : [x, w, b]
  new.apply(operands).first
end
|
Instance Method Details
#backward(indexes, grad_outputs) ⇒ Object
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
# File 'lib/chainer/functions/connection/linear.rb', line 33
# Builds the gradients for the requested input positions.
#
# Position 0 yields linear(gy, w.transpose) cast to x.dtype, position 1
# yields linear(gy.transpose, x.transpose) cast to w.dtype, and position 2
# yields the sum of gy over axis 0. Results are always emitted in the
# order 0, 1, 2, regardless of the order of +indexes+; positions that are
# not requested are simply left out.
#
# @param indexes collection queried with +include?+ for 0, 1 and 2
# @param grad_outputs collection whose first element is the upstream gradient
# @return [Array] the gradients for the requested positions
def backward(indexes, grad_outputs)
  input, weight = get_retained_inputs
  upstream = grad_outputs.first

  requested = [0, 1, 2].select { |position| indexes.include?(position) }
  requested.map do |position|
    case position
    when 0
      grad_input = LinearFunction.linear(upstream, weight.transpose)
      Chainer::Functions::Array::Cast.cast(grad_input, input.dtype)
    when 1
      grad_weight = LinearFunction.linear(upstream.transpose, input.transpose)
      Chainer::Functions::Array::Cast.cast(grad_weight, weight.dtype)
    else
      Chainer::Functions::Math::Sum.sum(upstream, axis: 0)
    end
  end
end
|
#forward(inputs) ⇒ Object
19
20
21
22
23
24
25
26
27
28
29
30
31
|
# File 'lib/chainer/functions/connection/linear.rb', line 19
# Computes x.dot(w.transpose), casts the product to x's class, and adds the
# third input as a bias when exactly three inputs are supplied.
#
# Inputs 0 and 1 are registered through +retain_inputs+ so that +backward+
# can read them back later.
#
# @param inputs [Array] +[x, w]+ or +[x, w, b]+
# @return [Array] single-element array holding the result
def forward(inputs)
  x, w, bias = inputs
  product = x.dot(w.transpose).cast_to(x.class)
  # The bias is applied purely on input count, not on whether it is nil.
  output = inputs.size == 3 ? product + bias : product
  retain_inputs([0, 1])
  [output]
end
|