Class: Chainer::Functions::Connection::LinearFunction
- Inherits: Chainer::Function
  - Object
  - Chainer::Function
  - Chainer::Functions::Connection::LinearFunction
- Defined in: lib/chainer/functions/connection/linear.rb
Instance Attribute Summary
Attributes inherited from Chainer::Function
#inputs, #output_data, #outputs, #rank, #retain_after_backward
Class Method Summary
.linear(x, w, b = nil) ⇒ Object
Instance Method Summary
#backward(inputs, grad_outputs) ⇒ Object
#forward(inputs) ⇒ Object
Methods inherited from Chainer::Function
#call, #forward_cpu, #initialize, #retain_inputs, #retain_outputs
Constructor Details
This class inherits a constructor from Chainer::Function
Class Method Details
.linear(x, w, b = nil) ⇒ Object
    # File 'lib/chainer/functions/connection/linear.rb', line 5
    def self.linear(x, w, b=nil)
      if b.nil?
        self.new.(x, w)
      else
        self.new.(x, w, b)
      end
    end
Instance Method Details
#backward(inputs, grad_outputs) ⇒ Object
    # File 'lib/chainer/functions/connection/linear.rb', line 25
    def backward(inputs, grad_outputs)
      x = as_mat(inputs[0])
      w = inputs[1]
      gy = grad_outputs[0]
      gx = gy.dot(w).cast_to(x.class).reshape(*inputs[0].shape)
      gw = gy.transpose.dot(x).cast_to(w.class)
      if inputs.size == 3
        gb = gy.sum(0)
        [gx, gw, gb]
      else
        [gx, gw]
      end
    end
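The three gradients computed by #backward can be checked by hand on a small case. The following is a minimal sketch using nested Ruby arrays instead of the array type the library operates on; `matmul_sketch` and `linear_backward_sketch` are hypothetical helpers written for this illustration, and the `with_bias:` keyword stands in for the `inputs.size == 3` check. The `as_mat`, `cast_to`, and `reshape` steps of the real method are omitted.

```ruby
# Multiply two matrices stored as arrays of rows.
def matmul_sketch(a, b)
  b_cols = b.transpose
  a.map { |row| b_cols.map { |col| row.zip(col).sum { |p, q| p * q } } }
end

# Hypothetical stand-in for #backward on plain arrays.
# x: (batch, in), w: (out, in), gy: (batch, out)
def linear_backward_sketch(x, w, gy, with_bias: false)
  gx = matmul_sketch(gy, w)            # (batch, out) x (out, in)   -> (batch, in)
  gw = matmul_sketch(gy.transpose, x)  # (out, batch) x (batch, in) -> (out, in)
  return [gx, gw] unless with_bias

  gb = gy.transpose.map(&:sum)         # sum over the batch axis, like gy.sum(0)
  [gx, gw, gb]
end

x  = [[1.0, 2.0], [3.0, 4.0]]              # batch of 2, in_size 2
w  = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # out_size 3, in_size 2
gy = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]    # upstream gradient of ones

gx, gw, gb = linear_backward_sketch(x, w, gy, with_bias: true)
p gx  # gradient with respect to x, same shape as x
p gw  # gradient with respect to w, same shape as w
p gb  # gradient with respect to b, one entry per output unit
```

Each gradient has the shape of the input it belongs to, which is why the real method reshapes `gx` back to `inputs[0].shape` after the matrix product.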
#forward(inputs) ⇒ Object
    # File 'lib/chainer/functions/connection/linear.rb', line 13
    def forward(inputs)
      x = as_mat(inputs[0])
      w = inputs[1]
      y = x.dot(w.transpose).cast_to(x.class)
      if inputs.size == 3
        b = inputs[2]
        y += b
      end
      return [y]
    end
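As the listing shows, #forward computes the affine map y = x · wᵀ (+ b), so `w` is laid out with one row per output unit. The following is a minimal sketch of that arithmetic on nested Ruby arrays; `linear_forward_sketch` is a hypothetical helper written for this illustration, not part of the library, and it skips the `as_mat` and `cast_to` steps.

```ruby
# Hypothetical stand-in for #forward on plain arrays.
# x: batch_size rows of in_size values; w: out_size rows of in_size values.
def linear_forward_sketch(x, w, b = nil)
  # Each output entry is the dot product of an input row with a weight row,
  # which is what x.dot(w.transpose) computes.
  y = x.map do |row|
    w.map { |w_row| row.zip(w_row).sum { |xi, wi| xi * wi } }
  end
  return y if b.nil?

  # Add the bias to every row, mirroring `y += b`.
  y.map { |row| row.zip(b).map { |yi, bi| yi + bi } }
end

x = [[1.0, 2.0], [3.0, 4.0]]              # batch of 2, in_size 2
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # out_size 3, in_size 2
b = [0.5, 0.5, 0.5]

y_no_bias = linear_forward_sketch(x, w)
y_biased  = linear_forward_sketch(x, w, b)
p y_no_bias
p y_biased
```

The output has one row per input sample and one column per row of `w`, with or without the optional bias.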