Class: Chainer::Functions::Array::Rollaxis
- Inherits: Chainer::FunctionNode
  - Object
    - Chainer::FunctionNode
      - Chainer::Functions::Array::Rollaxis
- Defined in: lib/chainer/functions/array/rollaxis.rb
Overview
Roll axis of an array.
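The sketch below spells out the index arithmetic behind "rolling" an axis. It assumes that `Chainer::Utils::Array.rollaxis` follows NumPy's `rollaxis` semantics, which this class's `axis`/`start` interface mirrors. Under that assumption, the steps are:

1. Normalize a negative `axis` or `start` by adding the rank.
2. Shift `start` down by one when the axis lies before it.
3. Remove the axis and reinsert it at `start`.

`rollaxis_permutation` and `rolled_shape` are hypothetical helpers that work on shapes only. They are not part of red-chainer.

```ruby
# Hypothetical helper: the axis order a NumPy-style rollaxis would produce.
# Assumes Chainer::Utils::Array.rollaxis follows NumPy's semantics.
def rollaxis_permutation(ndim, axis, start = 0)
  axis += ndim if axis < 0
  start += ndim if start < 0
  raise ArgumentError, 'axis out of range' unless (0...ndim).cover?(axis)
  raise ArgumentError, 'start out of range' unless (0..ndim).cover?(start)

  # Removing the axis shifts every later position left by one.
  start -= 1 if axis < start
  axes = (0...ndim).to_a
  axes.delete_at(axis)
  axes.insert(start, axis) # Array#insert returns the array itself
end

# Hypothetical helper: the output shape after rolling `axis` to `start`.
def rolled_shape(shape, axis, start = 0)
  rollaxis_permutation(shape.size, axis, start).map { |ax| shape[ax] }
end

p rolled_shape([3, 4, 5, 6], 3, 1) # => [3, 6, 4, 5]
p rolled_shape([3, 4, 5, 6], 2)    # => [5, 3, 4, 6]
p rolled_shape([3, 4, 5, 6], 1, 4) # => [3, 5, 6, 4]
```

These three calls reproduce the shape examples in NumPy's `rollaxis` documentation. The second call uses the default `start` of 0, matching the default of `.rollaxis`.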
Instance Attribute Summary
Attributes inherited from Chainer::FunctionNode
Class Method Summary
- .rollaxis(x, axis, start: 0) ⇒ Chainer::Variable
  Roll the axis backwards to the given position.
Instance Method Summary
- #backward(indexes, gy) ⇒ Object
- #forward(inputs) ⇒ Object
- #initialize(axis, start) ⇒ Rollaxis
  (constructor) A new instance of Rollaxis.
Methods inherited from Chainer::FunctionNode
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Constructor Details
#initialize(axis, start) ⇒ Rollaxis
Returns a new instance of Rollaxis.
    # File 'lib/chainer/functions/array/rollaxis.rb', line 16
    def initialize(axis, start)
      unless axis.is_a?(Integer)
        raise ArgumentError, 'axis must be int'
      end
      unless start.is_a?(Integer)
        raise ArgumentError, 'start must be int'
      end
      @axis = axis
      @start = start
    end
Class Method Details
.rollaxis(x, axis, start: 0) ⇒ Chainer::Variable
Roll the axis backwards to the given position.
    # File 'lib/chainer/functions/array/rollaxis.rb', line 12
    def self.rollaxis(x, axis, start: 0)
      Rollaxis.new(axis, start).apply([x]).first
    end
Instance Method Details
#backward(indexes, gy) ⇒ Object
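    # File 'lib/chainer/functions/array/rollaxis.rb', line 36
    def backward(indexes, gy)
      # Normalize negative indices with the input rank recorded in #forward.
      axis = @axis
      if axis < 0
        axis += @in_ndim
      end
      start = @start
      if start < 0
        start += @in_ndim
      end
      # Swap the roles of axis and start so the gradient's moved axis is
      # rolled back to its original position; the +1/-1 offsets account for
      # rollaxis placing the axis before the position `start`.
      if axis > start
        axis += 1
      else
        start -= 1
      end
      Rollaxis.new(start, axis).apply(gy)
    end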
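#backward undoes the forward roll by building a second Rollaxis. That Rollaxis swaps the roles of `axis` and `start` and adjusts one of them by one.

The pure-Ruby check below replays that arithmetic on a list of axis labels. It again assumes NumPy-style rollaxis semantics and uses two hypothetical helpers, `roll` and `backward_args`. It collects every raw `(axis, start)` pair for which rolling back does not restore the original order.

By this sketch's reckoning, the only mismatches are the pairs that normalize to `axis == start == 0`. In that case the forward roll is a no-op. However, the `else` branch yields `start == -1`, which the second Rollaxis reads as the last axis.

```ruby
# Hypothetical helper: NumPy-style rollaxis applied to a list of axis labels.
def roll(labels, axis, start)
  n = labels.size
  axis += n if axis < 0
  start += n if start < 0
  start -= 1 if axis < start
  out = labels.dup
  moved = out.delete_at(axis)
  out.insert(start, moved)
end

# Hypothetical helper: the same index arithmetic as Rollaxis#backward.
# Returns the [axis, start] pair handed to the gradient's Rollaxis.
def backward_args(axis, start, in_ndim)
  axis += in_ndim if axis < 0
  start += in_ndim if start < 0
  if axis > start
    axis += 1
  else
    start -= 1
  end
  [start, axis]
end

labels = %i[a b c d]
n = labels.size
mismatches = []
(-n...n).each do |axis|          # every valid axis, negative forms included
  (-n..n).each do |start|        # every valid start, negative forms included
    rolled = roll(labels, axis, start)
    back_axis, back_start = backward_args(axis, start, n)
    mismatches << [axis, start] unless roll(rolled, back_axis, back_start) == labels
  end
end

p mismatches # => [[-4, -4], [-4, 0], [0, -4], [0, 0]]
```

One possible adjustment, not taken from the library, is to decrement `start` only when `axis < start`. That change turns the `axis == start` case into `Rollaxis.new(k, k)`, which is a no-op.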
#forward(inputs) ⇒ Object
    # File 'lib/chainer/functions/array/rollaxis.rb', line 29
    def forward(inputs)
      # #backward needs only the input rank, so no input arrays are retained.
      retain_inputs([])
      @in_ndim = inputs.first.ndim
      [Chainer::Utils::Array.rollaxis(inputs.first, @axis, start: @start)]
    end
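#forward delegates the data movement to `Chainer::Utils::Array.rollaxis` and records only the input rank. A roll only re-indexes elements, so computing the gradient needs no input values.

The sketch below imitates the forward step on plain nested Ruby arrays rather than on the numeric arrays red-chainer operates on. It assumes NumPy-style semantics, where the roll is a transpose by the permutation from the rollaxis rule. `rollaxis_permutation`, `shape_of`, and `transpose_nested` are hypothetical helpers, not red-chainer API.

```ruby
# Hypothetical helper: NumPy-style rollaxis permutation (indices assumed valid).
def rollaxis_permutation(ndim, axis, start = 0)
  axis += ndim if axis < 0
  start += ndim if start < 0
  start -= 1 if axis < start
  axes = (0...ndim).to_a
  axes.delete_at(axis)
  axes.insert(start, axis)
end

# Shape of a rectangular nested Array, e.g. [[1, 2, 3], [4, 5, 6]] -> [2, 3].
def shape_of(a)
  a.is_a?(Array) ? [a.size] + shape_of(a.first) : []
end

# Output axis k takes input axis perm[k], as in a NumPy-style transpose.
def transpose_nested(a, perm)
  in_shape = shape_of(a)
  out_shape = perm.map { |ax| in_shape[ax] }
  build = lambda do |prefix|
    depth = prefix.size
    if depth == out_shape.size
      # Map the output index back to the input index and fetch the element.
      in_idx = Array.new(perm.size)
      perm.each_with_index { |ax, k| in_idx[ax] = prefix[k] }
      in_idx.reduce(a) { |sub, i| sub[i] }
    else
      (0...out_shape[depth]).map { |i| build.call(prefix + [i]) }
    end
  end
  build.call([])
end

x = [[1, 2, 3], [4, 5, 6]]
perm = rollaxis_permutation(2, 1) # move axis 1 to position 0
p transpose_nested(x, perm)       # => [[1, 4], [2, 5], [3, 6]]
```

Applying the inverse permutation to the gradient restores the input layout. #backward arranges this by constructing its second Rollaxis.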