Class: Chainer::Functions::Connection::EmbedIDFunction

Inherits:
Chainer::Function
Defined in:
lib/chainer/functions/connection/embed_id.rb

Instance Attribute Summary

Attributes inherited from Chainer::Function

#inputs, #output_data, #outputs, #owned_node, #rank, #retain_after_backward

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::Function

#backward_cpu, #backward_gpu, #call, #forward_cpu, #forward_gpu, #label, #node, #retain_inputs, #retain_outputs

Constructor Details

#initialize(ignore_label: nil) ⇒ EmbedIDFunction

Returns a new instance of EmbedIDFunction.



5
6
7
# File 'lib/chainer/functions/connection/embed_id.rb', line 5

# Creates the function object and remembers which id value, if any, should be
# treated as "ignored" by #forward and #backward.
#
# @param ignore_label [Object, nil] id value to skip; +nil+ (the default)
#   means no id is ignored.
def initialize(ignore_label: nil)
  # Stored as-is; both #forward and #backward read this instance variable.
  @ignore_label = ignore_label
end

Class Method Details

.embed_id(x, w, ignore_label: nil) ⇒ Object



9
10
11
# File 'lib/chainer/functions/connection/embed_id.rb', line 9

# Convenience entry point: builds a function instance configured with
# +ignore_label+ and immediately invokes it on +x+ and +w+.
#
# @param x [Object] first input handed to the function call (the ids).
# @param w [Object] second input handed to the function call (the weights).
# @param ignore_label [Object, nil] forwarded unchanged to the constructor.
# @return [Object] whatever the instance's +call+ returns.
def self.embed_id(x, w, ignore_label: nil)
  function = new(ignore_label: ignore_label)
  function.call(x, w)
end

Instance Method Details

#backward(inputs, grad_outputs) ⇒ Object



34
35
36
37
38
39
40
41
42
43
44
45
# File 'lib/chainer/functions/connection/embed_id.rb', line 34

# Computes the gradient with respect to the weight array by summing, for each
# id in +x+, the matching row of the upstream gradient into that id's row.
# Positions whose id equals +@ignore_label+ contribute nothing.
#
# @param inputs [Array] the pair +[x, w]+ (ids and weights).
# @param grad_outputs [Array] upstream gradients; only the first is used.
# @return [Array] +[nil, gw]+ — no gradient for the ids, and +gw+ reshaped
#   back to +w.shape+.
def backward(inputs, grad_outputs)
  ids, weight = inputs
  id_count = ids.size

  # One upstream-gradient row per id.
  # NOTE(review): the trailing `true` presumably asks the array library to
  # infer the remaining dimension — confirm against the NArray docs.
  upstream = grad_outputs[0].reshape(id_count, true)

  # Zero accumulator with the same class and shape as the weights, viewed as
  # (product of all leading dims) x (last dim) so it can be indexed by id.
  weight_shape = weight.shape
  zero_weights = weight.class.zeros(weight_shape)
  row_count = weight_shape.take(weight_shape.size - 1).reduce(&:*)
  grad_w = zero_weights.reshape(row_count, true)

  # Walk the ids flattened to one dimension; repeated ids accumulate.
  ids.reshape(id_count).each_with_index do |id, pos|
    unless id == @ignore_label
      grad_w[id, true] = grad_w[id, true] + upstream[pos, true]
    end
  end

  [nil, grad_w.reshape(*weight.shape)]
end

#forward(inputs) ⇒ Object



13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# File 'lib/chainer/functions/connection/embed_id.rb', line 13

# Looks up rows of +w+ by the ids in +x+ via +Chainer::Utils::Array.take+
# along axis 0. When +@ignore_label+ is set and some ids equal it, the output
# rows at those positions are replaced with zeros.
#
# @param inputs [Array] the pair +[x, w]+ (ids and weights).
# @return [Array] a one-element array holding the looked-up rows.
def forward(inputs)
  # NOTE(review): the result is not used in this method; the call is kept so
  # behaviour stays identical to the original.
  Chainer.get_array_module(*inputs)
  ids, weight = inputs

  # No ignore label configured: plain lookup.
  return [Chainer::Utils::Array.take(weight, ids, axis: 0)] unless @ignore_label

  # Mask of positions whose id differs from the ignore label.
  keep_mask = ids.ne(@ignore_label)
  # Every id is valid: same plain lookup.
  return [Chainer::Utils::Array.take(weight, ids, axis: 0)] if keep_mask.count == ids.size

  # Multiply ignored ids by the mask before the lookup.
  # NOTE(review): presumably this turns them into 0 so the index stays in
  # range; the rows fetched for them are overwritten with zeros below.
  ids *= keep_mask
  rows = Chainer::Utils::Array.take(weight, ids, axis: 0).dup

  # View as (product of leading dims) x (last dim) to address one row per id.
  rows_shape = rows.shape
  leading = rows_shape.take(rows_shape.size - 1).reduce(&:*)
  rows = rows.reshape(leading, true)

  # NOTE(review): `where2.last` presumably lists the flat positions where the
  # mask is false, i.e. the ignored ids — confirm against the NArray docs.
  keep_mask.where2.last.each do |pos|
    rows[pos, true] = rows.class.zeros(rows.shape.last)
  end

  [rows.reshape(*ids.shape, true)]
end