Class: Chainer::Functions::Connection::LinearFunction

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/connection/linear.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #initialize, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

This class inherits a constructor from Chainer::FunctionNode

Class Method Details

.linear(x, w, b = nil) ⇒ Object



(Source lines 5–17)
# File 'lib/chainer/functions/connection/linear.rb', line 5

# Applies a linear (fully-connected) transformation, y = x · wᵀ (+ b),
# by building a fresh LinearFunction node and applying it.
#
# Inputs with more than two axes are flattened to two axes. The leading
# axis is kept and all the rest are merged (presumably the leading axis is
# the batch axis; confirm with callers).
#
# @param x input array/variable
# @param w weight; forward multiplies x by the transpose of w
# @param b optional bias, added to the product only when non-nil
# @return the first output produced by the applied node
def self.linear(x, w, b = nil)
  x = x.reshape(x.shape.first, -1) if x.ndim > 2

  # The node's forward detects the bias by the number of inputs (3 vs 2).
  operands = b.nil? ? [x, w] : [x, w, b]
  new.apply(operands).first
end

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object



(Source lines 33–51)
# File 'lib/chainer/functions/connection/linear.rb', line 33

# Computes gradients of y = x · wᵀ (+ b) for the requested inputs.
#
# @param indexes [Array<Integer>] which input gradients to produce:
#   0 for x, 1 for w, 2 for b
# @param grad_outputs [Array] upstream gradients; only the first (gy) is used
# @return [Array] the requested gradients, always ordered x, w, b,
#   regardless of the order of +indexes+
def backward(indexes, grad_outputs)
  input, weight = get_retained_inputs
  upstream = grad_outputs.first
  grads = []

  # dL/dx = gy · w. linear multiplies by the transpose of its second
  # argument, so wᵀ is passed in. The result is cast back to x's dtype.
  if indexes.include?(0)
    grad_input = LinearFunction.linear(upstream, weight.transpose)
    grads.push(Chainer::Functions::Array::Cast.cast(grad_input, input.dtype))
  end

  # dL/dw = gyᵀ · x, cast back to w's dtype.
  if indexes.include?(1)
    grad_weight = LinearFunction.linear(upstream.transpose, input.transpose)
    grads.push(Chainer::Functions::Array::Cast.cast(grad_weight, weight.dtype))
  end

  # dL/db = gy summed over axis 0 (no dtype cast, matching the other branches' absence for b).
  grads.push(Chainer::Functions::Math::Sum.sum(upstream, axis: 0)) if indexes.include?(2)

  grads
end

#forward(inputs) ⇒ Object



(Source lines 19–31)
# File 'lib/chainer/functions/connection/linear.rb', line 19

# Forward pass: y = x · wᵀ, plus b when a third input is supplied.
#
# Inputs x and w (positions 0 and 1) are retained for use by backward.
#
# @param inputs [Array] x, w, and optionally b
# @return [Array] one-element array holding y, cast to the class of x
def forward(inputs)
  x, w, bias = inputs

  product = x.dot(w.transpose).cast_to(x.class)
  product += bias if inputs.size == 3

  retain_inputs([0, 1])
  [product]
end