Class: Chainer::Functions::Connection::EmbedIDFunction
- Inherits: Chainer::Function (ancestry: Object → Chainer::Function → Chainer::Functions::Connection::EmbedIDFunction)
- Defined in: lib/chainer/functions/connection/embed_id.rb
Instance Attribute Summary
Attributes inherited from Chainer::Function
#inputs, #output_data, #outputs, #owned_node, #rank, #retain_after_backward
Class Method Summary
- .embed_id(x, w, ignore_label: nil) ⇒ Object
Instance Method Summary
- #backward(inputs, grad_outputs) ⇒ Object
- #forward(inputs) ⇒ Object
- #initialize(ignore_label: nil) ⇒ EmbedIDFunction — constructor; returns a new instance of EmbedIDFunction.
Methods inherited from Chainer::Function
#backward_cpu, #backward_gpu, #call, #forward_cpu, #forward_gpu, #label, #node, #retain_inputs, #retain_outputs
Constructor Details
#initialize(ignore_label: nil) ⇒ EmbedIDFunction
Returns a new instance of EmbedIDFunction.
Source lines 5–7:
# File 'lib/chainer/functions/connection/embed_id.rb', line 5
# Creates the function and records the optional ID to mask out.
#
# @param ignore_label [Object, nil] ID value whose positions #forward fills
#   with zero rows and #backward skips; +nil+ (the default) disables masking
def initialize(ignore_label: nil)
  @ignore_label = ignore_label
end
Class Method Details
.embed_id(x, w, ignore_label: nil) ⇒ Object
Source lines 9–11:
# File 'lib/chainer/functions/connection/embed_id.rb', line 9
# Convenience entry point: builds an EmbedIDFunction with the given
# +ignore_label+ and immediately applies it to +x+ and +w+.
#
# @param x the integer IDs to look up
# @param w the embedding matrix whose rows are gathered
# @param ignore_label [Object, nil] ID whose positions are masked out;
#   +nil+ disables masking
# @return [Object] whatever Chainer::Function#call returns for +(x, w)+
def self.embed_id(x, w, ignore_label: nil)
  new(ignore_label: ignore_label).call(x, w)
end
Instance Method Details
#backward(inputs, grad_outputs) ⇒ Object
Source lines 34–45:
# File 'lib/chainer/functions/connection/embed_id.rb', line 34
# Accumulates the upstream gradient rows into a zero-initialized gradient
# for the embedding matrix. Positions whose ID equals +@ignore_label+
# contribute nothing. Repeated IDs are summed because rows are added one
# at a time.
#
# @param inputs [Array] the forward inputs as +[x, w]+
# @param grad_outputs [Array] upstream gradients; only the first entry is used
# @return [Array] +[nil, gw]+: no gradient for the IDs, and +gw+ reshaped to +w.shape+
def backward(inputs, grad_outputs)
  ids, weight = inputs

  # One upstream row per flattened ID position.
  upstream_rows = grad_outputs[0].reshape(ids.size, true)

  # View the gradient as (product of all leading dims, last dim).
  leading_dims = weight.shape.first(weight.shape.size - 1).inject(:*)
  grad_weight = weight.class.zeros(weight.shape).reshape(leading_dims, true)

  ids.reshape(ids.size).each_with_index do |id, position|
    next if id == @ignore_label

    grad_weight[id, true] += upstream_rows[position, true]
  end

  [nil, grad_weight.reshape(*weight.shape)]
end
#forward(inputs) ⇒ Object
Source lines 13–32:
# File 'lib/chainer/functions/connection/embed_id.rb', line 13
# Gathers the rows of +w+ selected by the IDs in +x+ (via
# Chainer::Utils::Array.take along axis 0).
#
# When +@ignore_label+ is set and at least one ID equals it, the gathered
# rows at those positions are replaced with zeros. In that branch the
# result is reshaped to +(*x.shape, true)+.
#
# @param inputs [Array] +[x, w]+: the IDs and the embedding matrix
# @return [Array] a one-element array holding the gathered embeddings
def forward(inputs)
  x, w = inputs

  # No ignore label configured: a plain gather is enough.
  return [Chainer::Utils::Array.take(w, x, axis: 0)] unless @ignore_label

  valid_x = x.ne(@ignore_label)
  # Every ID is valid, so no masking is needed.
  return [Chainer::Utils::Array.take(w, x, axis: 0)] if valid_x.count == x.size

  # Multiplying by the 0/1 mask turns ignored IDs into 0, presumably so an
  # out-of-range ignore_label cannot break the gather. Those rows are
  # overwritten with zeros below anyway.
  x *= valid_x
  y = Chainer::Utils::Array.take(w, x, axis: 0).dup
  y = y.reshape(y.shape.take(y.shape.size - 1).reduce(&:*), true)

  # where2.last holds the flat positions where the mask is false (ignored IDs).
  # A scalar fill avoids allocating a fresh zero vector per ignored position.
  valid_x.where2.last.each { |i| y[i, true] = 0 }

  [y.reshape(*x.shape, true)]
end