Class: Chainer::Functions::Connection::EmbedIDFunction
Instance Attribute Summary
Attributes inherited from the superclass: #inputs, #output_data, #outputs, #owned_node, #rank, #retain_after_backward
Class Method Summary
.embed_id(x, w, ignore_label: nil) ⇒ Object
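Instance Method Summary
#backward(inputs, grad_outputs) ⇒ Object
#forward(inputs) ⇒ Object
#initialize(ignore_label: nil) ⇒ EmbedIDFunction (constructor)
Methods inherited from the superclass: #backward_cpu, #backward_gpu, #call, #forward_cpu, #forward_gpu, #label, #node, #retain_inputs, #retain_outputs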
Constructor Details
#initialize(ignore_label: nil) ⇒ EmbedIDFunction
5
6
7
|
# File 'lib/chainer/functions/connection/embed_id.rb', line 5
# Creates the embedding-lookup function.
#
# @param ignore_label [Integer, nil] ID to treat specially during forward and
#   backward (see those methods); +nil+ (the default) disables ignoring.
def initialize(ignore_label: nil)
  # Stored as-is; forward/backward compare each ID against it.
  @ignore_label = ignore_label
end
|
Class Method Details
.embed_id(x, w, ignore_label: nil) ⇒ Object
9
10
11
|
# File 'lib/chainer/functions/connection/embed_id.rb', line 9
# Convenience entry point: builds a function instance configured with
# +ignore_label+ and immediately applies it to +x+ and +w+.
#
# @param x IDs to look up (presumably an integer array or Variable — confirm with callers)
# @param w embedding matrix the IDs index into
# @param ignore_label [Integer, nil] ID whose positions are ignored; +nil+ disables it
# @return the result of calling the new function object with (x, w)
def self.embed_id(x, w, ignore_label: nil)
  function = new(ignore_label: ignore_label)
  function.call(x, w)
end
|
Instance Method Details
#backward(inputs, grad_outputs) ⇒ Object
34
35
36
37
38
39
40
41
42
43
44
45
|
# File 'lib/chainer/functions/connection/embed_id.rb', line 34
# Computes the gradient with respect to the embedding matrix +w+.
#
# Each row of the incoming gradient is added into the slice of +w+ selected
# by the corresponding ID in +x+. Positions whose ID equals the ignore label
# contribute nothing. No gradient is returned for the IDs themselves.
#
# @param inputs [Array] +[x, w]+: +x+ holds the IDs (assumed an integer
#   NArray — TODO confirm), +w+ the embedding matrix indexed by ID along axis 0
# @param grad_outputs [Array] gradients of the forward outputs; only the first is used
# @return [Array] +[nil, gw]+ where +gw+ has the same shape and class as +w+
def backward(inputs, grad_outputs)
  x, w = inputs
  ids = x.reshape(x.size)
  # One gradient row per element of x, in the same flat order as +ids+.
  gy = grad_outputs[0].reshape(x.size, true)
  # Flatten to (vocabulary, everything else) so gw[id, true] is exactly the
  # slice that forward looked up for +id+, whatever w's trailing shape is.
  # (The previous product-of-leading-dims form only matched this for 2-D w.)
  gw = w.class.zeros(w.shape).reshape(w.shape[0], true)
  ids.each_with_index do |id, i|
    # Ignored positions were given zero vectors in forward, so they get no gradient.
    next if id == @ignore_label

    # Accumulate rather than assign: the same ID may occur several times in x.
    gw[id, true] = gw[id, true] + gy[i, true]
  end
  [nil, gw.reshape(*w.shape)]
end
|
#forward(inputs) ⇒ Object
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
# File 'lib/chainer/functions/connection/embed_id.rb', line 13
# Looks up the slices of +w+ (along axis 0) selected by the IDs in +x+.
#
# When an ignore label is configured, every position of +x+ equal to it
# yields zeros instead of a slice of +w+.
#
# @param inputs [Array] +[x, w]+: +x+ holds the IDs (assumed an integer
#   NArray — TODO confirm), +w+ the embedding matrix
# @return [Array] single-element array holding the embeddings; on the
#   ignore path it is explicitly shaped +x.shape + w.shape[1..]+ (presumably
#   the same shape the plain lookup produces — confirm against
#   Chainer::Utils::Array.take)
def forward(inputs)
  x, w = inputs
  # NOTE(review): Chainer::Utils::Array.take is assumed to gather along
  # axis 0 like NumPy's take.
  return [Chainer::Utils::Array.take(w, x, axis: 0)] unless @ignore_label

  valid_x = x.ne(@ignore_label)
  # Fast path: no position carries the ignore label.
  return [Chainer::Utils::Array.take(w, x, axis: 0)] if valid_x.count == x.size

  # Map ignored IDs to 0 (valid_x is 0 there) so the lookup stays in range;
  # the rows fetched for those positions are zeroed below.
  safe_x = x * valid_x
  y = Chainer::Utils::Array.take(w, safe_x, axis: 0).dup
  # One row per element of x (flat order), so the flat indices from where2
  # address whole embeddings regardless of w's trailing shape.
  y = y.reshape(x.size, true)
  # where2.last holds the flat positions where valid_x is false (ignored IDs).
  valid_x.where2.last.each { |i| y[i, true] = 0 }
  [y.reshape(*x.shape, *w.shape.drop(1))]
end
|