Class: Chainer::Functions::Array::SelectItem

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/array/select_item.rb

Overview

Select, from each row of x, the element at the column index given by t.

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #initialize, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

This class inherits a constructor from Chainer::FunctionNode.

Class Method Details

.select_item(x, t) ⇒ Object

Select elements stored at the given indices.

This function returns $t.choose(x.T)$, which means that
$y[i] = x[i, t[i]]$ for every row index $i$.

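@param [Chainer::Variable] x Variable storing a two-dimensional array of shape $(B, N)$.
@param [Chainer::Variable] t Variable storing a one-dimensional integer array of column indices; its length should equal the number of rows of $x$.
@return [Chainer::Variable] Variable of shape $(B,)$ whose $i$-th element is $x[i, t[i]]$.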

# File 'lib/chainer/functions/array/select_item.rb', line 13

def self.select_item(x, t)
  SelectItem.new.apply([x, t]).first
end
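To make the indexing rule concrete, here is a plain-Ruby sketch that uses nested arrays instead of Chainer::Variable or Numo arrays (an assumption made so the snippet runs without red-chainer installed). It computes $y[i] = x[i, t[i]]$ directly and also via the NumPy-style `t.choose(x.T)` reading, and shows that both agree. The helper names `select_item_plain` and `choose_transposed` are hypothetical and not part of red-chainer.

```ruby
# Hypothetical helper: y[i] = x[i][t[i]] for every row i.
# x is an array of rows (shape [B, N]); t holds one column index per row.
# Array#fetch raises IndexError for a column index beyond the row length.
def select_item_plain(x, t)
  raise ArgumentError, "t.size must equal the number of rows of x" unless t.size == x.size

  t.each_with_index.map { |col, i| x[i].fetch(col) }
end

# Hypothetical helper mirroring NumPy's t.choose(x.T):
# choose picks choices[t[i]][i], and with choices = x.T that is x[i][t[i]].
def choose_transposed(x, t)
  choices = x.transpose
  t.each_with_index.map { |col, i| choices.fetch(col)[i] }
end

x = [[0.0, 1.0, 2.0],
     [3.0, 4.0, 5.0]]
t = [2, 0]

p select_item_plain(x, t)  # => [2.0, 3.0]
p choose_transposed(x, t)  # => [2.0, 3.0]
```

The explicit length check reflects the requirement that t have one index per row of x; the red-chainer `forward` shown on this page does not perform such a check itself.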

Instance Method Details

#backward(indexes, gy) ⇒ Object

# File 'lib/chainer/functions/array/select_item.rb', line 33

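def backward(indexes, gy)
  # Only t (input 1) was retained in #forward; x itself is not needed here.
  t = get_retained_inputs.first
  ret = []
  if indexes.include?(0)
    # Gradient w.r.t. x: Assign builds an array of x's shape and dtype
    # that is zero except at positions (i, t[i]), which receive gy.
    ggx = Assign.new(@in_shape, @in_dtype, t).apply(gy).first
    ret << ggx
  end
  if indexes.include?(1)
    # t holds integer indices, so it receives no gradient.
    ret << nil
  end
  ret
end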
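The gradient with respect to x is a scatter: every entry of the gradient is zero except position $(i, t[i])$, which receives $gy[i]$, and t gets `nil` because it holds integer indices. The plain-Ruby sketch below reproduces that rule with nested arrays. It assumes `Assign` behaves like its counterpart in Python Chainer (a zero-filled array of the input shape with gy written at $(i, t[i])$); `scatter_grad` is a hypothetical name, not a red-chainer method.

```ruby
# Hypothetical sketch of the gradient that SelectItem#backward obtains from Assign:
# start from zeros shaped like x and write gy[i] into column t[i] of row i.
def scatter_grad(in_shape, t, gy)
  rows, cols = in_shape
  gx = Array.new(rows) { Array.new(cols, 0.0) }
  t.each_with_index do |col, i|
    gx[i][col] = gy[i]
  end
  gx
end

gx = scatter_grad([2, 3], [2, 0], [10.0, 20.0])
p gx  # => [[0.0, 0.0, 10.0], [20.0, 0.0, 0.0]]
```

Plain assignment (rather than accumulation) is enough here because each row $i$ contributes exactly one selected element.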

#forward(inputs) ⇒ Object

# File 'lib/chainer/functions/array/select_item.rb', line 17

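def forward(inputs)
  # Keep only t (input 1) for #backward.
  retain_inputs([1])
  x, t = inputs
  # Remember x's shape and array class so #backward can build a matching gradient.
  @in_shape = x.shape
  @in_dtype = x.class

  # TODO: x[six.moves.range(t.size), t]
  # Gather y[i] = x[i, t[i]] one row at a time.
  new_x = x.class.zeros(t.size)
  t.size.times do |i|
    new_x[i] = x[i, t[i]]
  end

  [new_x]
end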
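The TODO in `forward` refers to the vectorized gather `x[range(t.size), t]` from the Python original. One way to express that gather with a single index operation is to turn each pair $(i, t[i])$ into a row-major flat offset $i \cdot N + t[i]$ and read all offsets from the flattened array at once. The sketch below shows the index arithmetic on plain Ruby arrays; `flat_gather` is a hypothetical helper, and it does not claim how a particular Numo version handles integer-array indexing.

```ruby
# Hypothetical helper: gather x[i][t[i]] through flat row-major offsets,
# the index arithmetic a vectorized x[range(t.size), t] would rely on.
# Assumes every row has the same length and 0 <= t[i] < n_cols; an
# out-of-range column would silently read from a neighboring row.
def flat_gather(x, t)
  n_cols = x.first.size
  offsets = t.each_with_index.map { |col, i| i * n_cols + col }
  x.flatten.values_at(*offsets)
end

x = [[0.0, 1.0, 2.0],
     [3.0, 4.0, 5.0],
     [6.0, 7.0, 8.0]]
t = [1, 2, 0]
p flat_gather(x, t)  # => [1.0, 5.0, 6.0]
```

The row-by-row loop in `forward` avoids the silent wrap-around noted in the comment, at the cost of one indexing call per row.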