Class: Chainer::Functions::Array::Rollaxis

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/array/rollaxis.rb

Overview

Roll an axis of an array to a given position.
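For a shape-level picture of what rolling does, the sketch below computes the resulting axis permutation in plain Ruby. It assumes the NumPy `numpy.rollaxis` convention (negative values count from the end; when `start` is greater than `axis`, the axis lands just before `start`), which is what the index arithmetic in `#backward` implies. `roll_permutation` and `rolled_shape` are hypothetical helpers, not part of red-chainer.

```ruby
# Hypothetical helper: the axis order produced by rolling `axis` to `start`,
# assuming NumPy's rollaxis rule (not confirmed by this page).
def roll_permutation(ndim, axis, start = 0)
  axis += ndim if axis < 0     # a negative axis counts from the end
  start += ndim if start < 0   # so does a negative start
  unless (0...ndim).cover?(axis) && (0..ndim).cover?(start)
    raise ArgumentError, 'axis or start out of range'
  end
  start -= 1 if axis < start   # `start` is counted before `axis` is removed
  axes = (0...ndim).to_a
  axes.delete_at(axis)
  axes.insert(start, axis)
  axes
end

# Output dimension i takes its size from input dimension perm[i].
def rolled_shape(shape, axis, start = 0)
  roll_permutation(shape.size, axis, start).map { |i| shape[i] }
end

p rolled_shape([2, 3, 4], 2)     # => [4, 2, 3]  (last axis to the front)
p rolled_shape([2, 3, 4], 0, 3)  # => [3, 4, 2]  (first axis to the end)
p rolled_shape([2, 3, 4], 0, 2)  # => [3, 2, 4]  (lands before old axis 2)
```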

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

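Class Method Summary

  • .rollaxis(x, axis, start: 0) ⇒ Chainer::Variable

    Roll the axis backwards to the given position.

Instance Method Summary

  • #backward(indexes, gy) ⇒ Object

  • #forward(inputs) ⇒ Object

  • #initialize(axis, start) ⇒ Rollaxis

    Returns a new instance of Rollaxis.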

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(axis, start) ⇒ Rollaxis

Returns a new instance of Rollaxis.



# File 'lib/chainer/functions/array/rollaxis.rb', line 16

def initialize(axis, start)
  unless axis.is_a?(Integer)
    raise ArgumentError, 'axis must be int'
  end

  unless start.is_a?(Integer)
    raise ArgumentError, 'start must be int'
  end

  @axis = axis
  @start = start
end
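The constructor checks only the types of `axis` and `start`; it performs no range check, and negative values are accepted here (`#backward` later normalizes them against the input rank). The hypothetical `RollaxisArgs` class below reproduces just those two checks so their behavior can be tried without red-chainer installed; it is a stand-in, not the library class.

```ruby
# Hypothetical stand-in mirroring only the type checks in Rollaxis#initialize.
class RollaxisArgs
  attr_reader :axis, :start

  def initialize(axis, start)
    raise ArgumentError, 'axis must be int' unless axis.is_a?(Integer)
    raise ArgumentError, 'start must be int' unless start.is_a?(Integer)

    @axis = axis
    @start = start
  end
end

RollaxisArgs.new(-1, 0)     # accepted: negative values pass the type check

begin
  RollaxisArgs.new(1.0, 0)  # a Float is rejected even when it is integral
rescue ArgumentError => e
  puts e.message            # prints "axis must be int"
end
```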

Class Method Details

.rollaxis(x, axis, start: 0) ⇒ Chainer::Variable

Roll the axis backwards to the given position.

Parameters:

  • x (Chainer::Variable)

    Input variable

  • axis (Integer)

    The axis to roll backwards.

  • start (Integer) (defaults to: 0)

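    The position to which the axis is moved. If start <= axis, the axis is rolled back until it lies at this position; if start > axis, it is rolled until it lies just before this position.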

Returns:

  • (Chainer::Variable)

    Variable whose axis is rolled.



# File 'lib/chainer/functions/array/rollaxis.rb', line 12

def self.rollaxis(x, axis, start: 0)
  Rollaxis.new(axis, start).apply([x]).first
end

Instance Method Details

#backward(indexes, gy) ⇒ Object



# File 'lib/chainer/functions/array/rollaxis.rb', line 36

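def backward(indexes, gy)
  # Normalize negative values against the input rank recorded in #forward.
  axis = @axis
  if axis < 0
    axis += @in_ndim
  end
  start = @start
  if start < 0
    start += @in_ndim
  end

  # Build the inverse roll: move the axis from where #forward put it back to
  # its original position. When axis == start the forward roll was a no-op,
  # so both values stay unchanged (decrementing a start of 0 would give -1,
  # i.e. the last axis).
  if axis > start
    axis += 1
  elsif axis < start
    start -= 1
  end

  Rollaxis.new(start, axis).apply(gy)
end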
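To see why `#backward` swaps and shifts the two arguments, the sketch below re-derives them in plain Ruby and checks, for every valid `axis`/`start` pair of a given rank, that rolling with the backward arguments undoes the forward roll. It assumes NumPy's rollaxis rule and treats `axis == start` as a no-op, leaving both values unchanged rather than decrementing `start`. `roll_permutation`, `backward_args`, and `backward_inverts_forward?` are hypothetical names used only for illustration.

```ruby
# Axis order after rolling `axis` to `start` (NumPy rollaxis rule, assumed).
def roll_permutation(ndim, axis, start)
  axis += ndim if axis < 0
  start += ndim if start < 0
  start -= 1 if axis < start
  axes = (0...ndim).to_a
  axes.delete_at(axis)
  axes.insert(start, axis)
  axes
end

# The (axis, start) pair for the inverse roll, following #backward's arithmetic.
def backward_args(ndim, axis, start)
  axis += ndim if axis < 0
  start += ndim if start < 0
  if axis > start
    axis += 1
  elsif axis < start
    start -= 1
  end
  [start, axis] # the old start becomes the new axis, and vice versa
end

# True if the backward roll restores the original axis order for every pair.
def backward_inverts_forward?(ndim)
  identity = (0...ndim).to_a
  (-ndim...ndim).all? do |axis|
    (-ndim..ndim).all? do |start|
      fwd = roll_permutation(ndim, axis, start)
      b_axis, b_start = backward_args(ndim, axis, start)
      bwd = roll_permutation(ndim, b_axis, b_start)
      # Final axis j is forward-output axis bwd[j], i.e. input axis fwd[bwd[j]].
      bwd.map { |j| fwd[j] } == identity
    end
  end
end

puts backward_inverts_forward?(3) # prints "true"
```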

#forward(inputs) ⇒ Object



# File 'lib/chainer/functions/array/rollaxis.rb', line 29

def forward(inputs)
  # No input arrays are needed by #backward; only the input rank is kept.
  retain_inputs([])
  @in_ndim = inputs.first.ndim

  [Chainer::Utils::Array.rollaxis(inputs.first, @axis, start: @start)]
end
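`#forward` delegates the actual data movement to `Chainer::Utils::Array.rollaxis`. As a dependency-free illustration of that step, the sketch below rolls an axis of a rectangular nested Ruby array instead of a Numo array. It assumes NumPy's rollaxis rule, and `shape_of`, `build`, and `rollaxis` are hypothetical helpers, not red-chainer code.

```ruby
# Shape of a rectangular nested array, e.g. [[1, 2, 3], [4, 5, 6]] -> [2, 3].
def shape_of(a)
  shape = []
  while a.is_a?(Array)
    shape << a.size
    a = a.first
  end
  shape
end

# Build a nested array of `shape`, filling each cell from its index path.
def build(shape, prefix = [], &blk)
  return blk.call(prefix) if shape.empty?

  Array.new(shape.first) { |i| build(shape.drop(1), prefix + [i], &blk) }
end

# Roll `axis` to position `start` (NumPy rollaxis rule, assumed).
def rollaxis(a, axis, start: 0)
  shape = shape_of(a)
  ndim = shape.size
  axis += ndim if axis < 0
  start += ndim if start < 0
  start -= 1 if axis < start
  perm = (0...ndim).to_a
  perm.delete_at(axis)
  perm.insert(start, axis)

  build(perm.map { |i| shape[i] }) do |out_idx|
    # Output position dst reads input axis perm[dst].
    in_idx = Array.new(ndim)
    perm.each_with_index { |src, dst| in_idx[src] = out_idx[dst] }
    a.dig(*in_idx)
  end
end

x = [[1, 2, 3], [4, 5, 6]]   # shape [2, 3]
p rollaxis(x, 1)             # => [[1, 4], [2, 5], [3, 6]]
```

For a 2-D input, rolling axis 1 to the front is an ordinary transpose; higher ranks follow the same index remapping.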