Class: Chainer::Functions::Array::SelectItem
- Inherits: Chainer::FunctionNode
  - Object
    - Chainer::FunctionNode
      - Chainer::Functions::Array::SelectItem
- Defined in: lib/chainer/functions/array/select_item.rb
Overview
Select elements stored in given indices.
Instance Attribute Summary
Attributes inherited from Chainer::FunctionNode
Class Method Summary
- .select_item(x, t) ⇒ Object
  Select elements stored in given indices.
Instance Method Summary
Methods inherited from Chainer::FunctionNode
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #initialize, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Constructor Details
This class inherits a constructor from Chainer::FunctionNode
Class Method Details
.select_item(x, t) ⇒ Object
Select elements stored in given indices.
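This function returns $t.choose(x.T)$ (in NumPy notation); that is, the output $y$ satisfies
$y[i] == x[i, t[i]]$ for all $i$.
@param [Chainer::Variable] x Variable storing a two-dimensional array; row $i$ is indexed by $t[i]$.
@param [Chainer::Variable] t Variable storing a one-dimensional array of integer indices, one per row of $x$.
@return [Chainer::Variable] Variable holding $x[i, t[i]]$ for each $i$, as a one-dimensional array of length $t.size$.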
# File 'lib/chainer/functions/array/select_item.rb', line 13
def self.select_item(x, t)
  SelectItem.new.apply([x, t]).first
end
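To make the indexing rule concrete, here is a minimal plain-Ruby sketch of what select_item computes, assuming x is given as an Array of rows and t as an Array of integers rather than as Chainer::Variable objects. The helper name `select_item_sketch` is hypothetical and not part of red-chainer; it only mirrors the rule $y[i] == x[i, t[i]]$.

```ruby
# Hypothetical plain-Ruby helper (not part of red-chainer) that mirrors
# select_item's forward rule on nested Arrays: y[i] == x[i][t[i]].
def select_item_sketch(x, t)
  raise ArgumentError, "t needs one index per row of x" unless t.size == x.size

  # Pair each row with its index and pick that column.
  x.zip(t).map { |row, col| row.fetch(col) }
end

x = [[0.1, 0.7, 0.2],
     [0.5, 0.3, 0.2]]
t = [1, 0]
p select_item_sketch(x, t) # prints [0.7, 0.5]
```

A typical use is picking, for each sample in a batch, the score of its target class. Note that `fetch` raises IndexError for an out-of-range column; how the library itself reacts to bad indices is not documented on this page.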
Instance Method Details
#backward(indexes, gy) ⇒ Object
# File 'lib/chainer/functions/array/select_item.rb', line 33
def backward(indexes, gy)
  t = get_retained_inputs.first
  ret = []
  if indexes.include?(0)
    ggx = Assign.new(@in_shape, @in_dtype, t).apply(gy).first
    ret << ggx
  end
  if indexes.include?(1)
    ret << nil
  end
  ret
end
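backward returns a gradient only for x (input index 0); the integer indices t receive nil. Since $y[i] = x[i, t[i]]$, the gradient with respect to $x[i, j]$ is $gy[i]$ when $j = t[i]$ and zero otherwise. The sketch below assumes, judging only from the arguments backward passes (x's shape and dtype plus t), that Assign builds such a scattered array; Assign's definition is not shown on this page, and `select_item_grad_sketch` is a hypothetical plain-Ruby helper working on nested Arrays.

```ruby
# Hypothetical plain-Ruby helper (not red-chainer's Assign): scatter the
# upstream gradient gy (one value per row) into a zero array shaped like x,
# at positions (i, t[i]). Every other entry of x gets zero gradient.
def select_item_grad_sketch(in_shape, t, gy)
  rows, cols = in_shape
  unless t.size == rows && gy.size == rows
    raise ArgumentError, "t and gy need one entry per row"
  end

  # Build each row separately so rows do not share one Array object.
  gx = Array.new(rows) { Array.new(cols, 0.0) }
  t.each_with_index do |col, i|
    gx[i][col] = gy[i] # only the selected entry receives gradient
  end
  gx
end

p select_item_grad_sketch([2, 3], [1, 0], [1.0, 2.0])
# prints [[0.0, 1.0, 0.0], [2.0, 0.0, 0.0]]
```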
#forward(inputs) ⇒ Object
# File 'lib/chainer/functions/array/select_item.rb', line 17
def forward(inputs)
  retain_inputs([1])
  x, t = inputs
  @in_shape = x.shape
  @in_dtype = x.class
  # TODO: x[six.moves.range(t.size), t]
  new_x = x.class.zeros(t.size)
  t.size.times.each do |i|
    new_x[i] = x[i, t[i]]
  end
  x = new_x
  [x]
end
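The TODO in forward refers to NumPy-style advanced indexing, which pairs row $i$ with column $t[i]$ without an explicit per-row loop. One way to express the same gather, shown here as a plain-Ruby sketch on nested Arrays, is to convert each pair to a row-major flat offset $i \cdot N + t[i]$, where $N$ is the number of columns, and read all offsets at once. The helper `select_item_flat_sketch` is hypothetical; whether red-chainer's array backend offers an equivalent vectorized call is not stated here.

```ruby
# Hypothetical plain-Ruby helper: gather x[i][t[i]] for all i via row-major
# flat offsets instead of indexing row by row.
def select_item_flat_sketch(x, t)
  cols = x.first.size
  flat = x.flatten # row-major: row 0, then row 1, ...
  offsets = t.each_with_index.map { |col, i| i * cols + col }
  flat.values_at(*offsets)
end

x = [[0.1, 0.7, 0.2],
     [0.5, 0.3, 0.2]]
p select_item_flat_sketch(x, [1, 0]) # prints [0.7, 0.5]
```

The flat-offset form trades the per-row loop for one index computation and one gather, which is the shape a vectorized backend call would typically take.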