Class: Chainer::Functions::Connection::LinearFunction

Inherits:
Chainer::Function
Defined in:
lib/chainer/functions/connection/linear.rb

Instance Attribute Summary

Attributes inherited from Chainer::Function

#inputs, #output_data, #outputs, #rank, #retain_after_backward

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::Function

#call, #forward_cpu, #initialize, #retain_inputs, #retain_outputs

Constructor Details

This class inherits a constructor from Chainer::Function

Class Method Details

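.linear(x, w, b = nil) ⇒ Object

Applies a fully connected (linear) transformation, y = x·Wᵀ + b. It creates a new LinearFunction and calls it on (x, w), or on (x, w, b) when a bias is given, so b is optional.
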
# File 'lib/chainer/functions/connection/linear.rb', line 5

def self.linear(x, w, b=nil)
  # `.()` is shorthand for `.call`; b is passed only when a bias is given
  if b.nil?
    self.new.(x, w)
  else
    self.new.(x, w, b)
  end
end
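For reference, the computation that `.linear` dispatches to can be sketched on plain Ruby Arrays. This is a minimal sketch, not red-chainer code: it assumes x has shape (N, in_size), W has shape (out_size, in_size) and b has shape (out_size), matching `x.dot(w.transpose)` in #forward, and the method name `linear_reference` is hypothetical. Nested Arrays stand in for Numo::NArray so the snippet runs without extra gems.

```ruby
# Hypothetical reference for y = x . W^T (+ b) on nested Ruby Arrays.
# Not part of red-chainer; it only mirrors the math of LinearFunction.
def linear_reference(x, w, b = nil)
  # x: N rows of in_size values; w: out_size rows of in_size values
  x.map do |row|
    w.each_with_index.map do |w_row, j|
      dot = row.zip(w_row).sum { |xi, wi| xi * wi }
      b.nil? ? dot : dot + b[j] # bias is optional, as in .linear
    end
  end
end

x = [[1.0, 2.0], [3.0, 4.0]]              # N = 2, in_size = 2
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # out_size = 3
b = [0.5, -0.5, 0.0]

p linear_reference(x, w)     # => [[1.0, 2.0, 3.0], [3.0, 4.0, 7.0]]
p linear_reference(x, w, b)  # => [[1.5, 1.5, 3.0], [3.5, 3.5, 7.0]]
```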

Instance Method Details

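#backward(inputs, grad_outputs) ⇒ Object

Computes the gradients with respect to each input: gx = gy·W, reshaped back to the original shape of x; gW = gyᵀ·x; and, when a bias was passed, gb = the sum of gy over the batch axis.
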
# File 'lib/chainer/functions/connection/linear.rb', line 25

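def backward(inputs, grad_outputs)
  x = as_mat(inputs[0])   # x flattened to (N, in_size)
  w = inputs[1]           # weight, (out_size, in_size)
  gy = grad_outputs[0]    # upstream gradient, (N, out_size)
  # dL/dx = gy . W, restored to the original (possibly N-D) shape of x
  gx = gy.dot(w).cast_to(x.class).reshape(*inputs[0].shape)
  # dL/dW = gy^T . x
  gw = gy.transpose.dot(x).cast_to(w.class)
  if inputs.size == 3
    gb = gy.sum(0)        # dL/db: sum of gy over the batch axis
    [gx, gw, gb]
  else
    [gx, gw]
  end
end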
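The gradient formulas in #backward can be checked numerically. The sketch below is an illustration, not red-chainer API: `matmul`, `linear_backward` and `surrogate_loss` are hypothetical helpers on nested Arrays, and the check assumes the scalar loss L = sum(gy ⊙ y) with y = x·Wᵀ + b, whose derivative with respect to y is exactly gy. A forward difference on one weight entry should then match the corresponding entry of gyᵀ·x.

```ruby
# Hypothetical pure-Ruby sketch of the gradients in #backward (not red-chainer
# API). Shapes: x (N, in), w (out, in), gy (N, out).

# (n x k) . (k x m) -> (n x m) on nested Arrays
def matmul(a, b)
  cols = b.transpose
  a.map { |row| cols.map { |col| row.zip(col).sum { |u, v| u * v } } }
end

def linear_backward(x, w, gy, with_bias: true)
  gx = matmul(gy, w)            # gy . W   -> (N, in)
  gw = matmul(gy.transpose, x)  # gy^T . x -> (out, in)
  return [gx, gw] unless with_bias

  gb = gy.transpose.map(&:sum)  # sum of gy over the batch axis -> (out)
  [gx, gw, gb]
end

# Scalar loss L = sum(gy * y) with y = x . W^T + b, so dL/dy = gy
def surrogate_loss(x, w, b, gy)
  y = matmul(x, w.transpose).map { |row| row.zip(b).map { |v, bj| v + bj } }
  y.flatten.zip(gy.flatten).sum { |u, v| u * v }
end

x  = [[1.0, 2.0], [3.0, -1.0]]
w  = [[0.5, -0.25], [2.0, 1.0], [0.0, 1.5]]
b  = [0.1, 0.2, 0.3]
gy = [[1.0, 0.0, -1.0], [0.5, 2.0, 1.0]]

gx, gw, gb = linear_backward(x, w, gy)

# Forward-difference check on a single weight entry, w[1][0]
eps = 1e-6
w_plus = w.map(&:dup)
w_plus[1][0] += eps
numeric = (surrogate_loss(x, w_plus, b, gy) - surrogate_loss(x, w, b, gy)) / eps

p gw                                 # => [[2.5, 1.5], [6.0, -2.0], [2.0, -3.0]]
puts((numeric - gw[1][0]).abs < 1e-4) # prints true
```

Because L is linear in W, the forward difference differs from the analytic value only by floating-point rounding, which keeps the tolerance comfortable.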

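#forward(inputs) ⇒ Object

Flattens x to a 2-D matrix with as_mat, computes y = x·Wᵀ, casts the result to the array class of x, and adds b when a third input is present. Returns a one-element Array containing y.
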
# File 'lib/chainer/functions/connection/linear.rb', line 13

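def forward(inputs)
  x = as_mat(inputs[0])   # x flattened to (N, in_size)
  w = inputs[1]           # weight, (out_size, in_size)

  # y = x . W^T, cast to the array class of x
  y = x.dot(w.transpose).cast_to(x.class)
  if inputs.size == 3
    b = inputs[2]
    y += b                # bias (out_size) is broadcast over the batch
  end
  [y]
end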
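The forward pass, including the flattening step, can be sketched as follows. This is a hedged illustration: as_mat is not shown on this page, so `flatten_to_matrix` is a hypothetical stand-in that assumes as_mat reshapes x from (N, d1, d2, ...) to (N, d1·d2·...), which is consistent with #backward reshaping gx back to inputs[0].shape. `linear_forward` is likewise hypothetical and works on nested Arrays rather than Numo::NArray.

```ruby
# Hypothetical sketch of #forward on nested Ruby Arrays. flatten_to_matrix
# is an assumed stand-in for as_mat: it keeps the batch axis and flattens
# all remaining axes, like reshape(N, -1).
def flatten_to_matrix(x)
  x.map(&:flatten)
end

def linear_forward(inputs)
  x = flatten_to_matrix(inputs[0])
  w = inputs[1]
  y = x.map { |row| w.map { |w_row| row.zip(w_row).sum { |a, c| a * c } } }
  if inputs.size == 3
    b = inputs[2]
    # Broadcast b over the batch axis: every row gets the same bias vector
    y = y.map { |row| row.zip(b).map { |v, bj| v + bj } }
  end
  [y]
end

# Batch of 2 samples, each a 2x2 "image", flattened to 4 features
x = [[[1.0, 2.0], [3.0, 4.0]],
     [[0.0, 1.0], [0.0, 1.0]]]
w = [[1.0, 1.0, 1.0, 1.0],   # out_size = 2, in_size = 4
     [1.0, 0.0, 0.0, -1.0]]
b = [0.0, 10.0]

y = linear_forward([x, w, b]).first
p y  # => [[10.0, 7.0], [2.0, 9.0]]
```

Returning a one-element Array mirrors the convention in #forward, where a Function returns its outputs as an Array even when there is only one.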