Class: Chainer::Functions::Array::Reshape

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/array/reshape.rb

Overview

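Reshapes an input array without copying its data. One dimension of the target shape may be given as `-1`, in which case its size is inferred from the total number of elements.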
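To illustrate what "without copying" means, here is a minimal pure-Ruby sketch, not red-chainer code. A hypothetical `FlatView` class holds a flat row-major buffer plus a shape, and reshaping returns a new view over the *same* buffer. In red-chainer the reshaping itself is delegated to the input array's own `reshape` method (see `#forward`).

```ruby
# Hypothetical illustration only: a row-major "view" over a flat buffer.
class FlatView
  attr_reader :data, :shape

  def initialize(data, shape)
    unless data.size == shape.reduce(1, :*)
      raise ArgumentError, "shape #{shape.inspect} does not match size #{data.size}"
    end
    @data = data
    @shape = shape
  end

  # Returns a new view with a different shape that shares the same buffer;
  # no elements are copied.
  def reshape(*new_shape)
    FlatView.new(@data, new_shape)
  end
end

x = FlatView.new((1..6).to_a, [2, 3])
y = x.reshape(3, 2)
p y.shape                 # => [3, 2]
p y.data.equal?(x.data)   # => true (same underlying buffer)
```

Because the buffer is shared, writing through one view is visible through the other; only the shape metadata differs.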

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(shape) ⇒ Reshape

Returns a new instance of Reshape.

# File 'lib/chainer/functions/array/reshape.rb', line 6

def initialize(shape)
  @shape = shape
end

Class Method Details

.reshape(x, shape) ⇒ Object

# File 'lib/chainer/functions/array/reshape.rb', line 10

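def self.reshape(x, shape)
  # Shortcut: if the input already has the requested shape, return it as a
  # Variable without adding a Reshape node to the computational graph.
  return Chainer::Variable.as_variable(x) if x.shape == shape
  new(shape).apply([x]).first
end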
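`.reshape` skips building a graph node when the input already has the requested shape, returning the input wrapped by `Chainer::Variable.as_variable` instead. Because the check is a literal `==` on the shape arrays, a target containing `-1` never takes the shortcut, even when it would resolve to the input's shape. The plain-Ruby sketch below mirrors only that comparison; `takes_shortcut?` is a hypothetical helper, not part of red-chainer.

```ruby
# Hypothetical helper mirroring the `x.shape == shape` test in `.reshape`.
def takes_shortcut?(input_shape, target_shape)
  input_shape == target_shape
end

p takes_shortcut?([2, 3], [2, 3])   # => true:  input returned as a Variable, no new node
p takes_shortcut?([2, 3], [2, -1])  # => false: a Reshape node is created although -1 resolves to 3
p takes_shortcut?([2, 3], [3, 2])   # => false: a real reshape
```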

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object

# File 'lib/chainer/functions/array/reshape.rb', line 21

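def backward(indexes, grad_outputs)
  gx = grad_outputs.first
  # Reshaping never reorders elements, so the gradient w.r.t. the input is
  # the output gradient reshaped back to the input's original shape.
  [Reshape.reshape(gx, @inputs.first.shape)]
end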
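Because reshaping only reinterprets the layout and never reorders elements, its gradient is the incoming gradient reshaped back to the input's shape, which is what `#backward` does by calling `Reshape.reshape(gx, @inputs.first.shape)`. The plain-Ruby sketch below checks that property on flat row-major buffers; `reshape_forward` and `reshape_grad` are hypothetical names, and the `[flat_data, shape]` pair stands in for a real array.

```ruby
# Hypothetical sketch: an array is represented as [flat_row_major_data, shape].
def numel(shape)
  shape.reduce(1, :*)
end

# Forward: same flat data, new shape.
def reshape_forward(x, new_shape)
  data, shape = x
  raise ArgumentError, "size mismatch" unless numel(shape) == numel(new_shape)
  [data, new_shape]
end

# Backward: reshape the output gradient back to the input's shape.
def reshape_grad(gy, input_shape)
  reshape_forward(gy, input_shape)
end

x  = [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]]
y  = reshape_forward(x, [3, 2])
gy = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], y[1]]  # gradient w.r.t. y, same shape as y
gx = reshape_grad(gy, x[1])
p gx[1]            # => [2, 3]
p gx[0] == gy[0]   # => true: each element's gradient passes through unchanged
```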

#forward(inputs) ⇒ Object

# File 'lib/chainer/functions/array/reshape.rb', line 15

def forward(inputs)
  x = inputs.first
  # The backend array's #reshape expects nil (not -1) for the inferred dimension.
  new_shape = @shape.map { |s| s == -1 ? nil : s }
  [x.reshape(*new_shape)]
end
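`#forward` translates every `-1` in the requested shape into `nil` before calling the input array's `reshape`; in Numo::NArray a `nil` dimension means "infer this size from the others". The plain-Ruby sketch below shows that inference under the assumption that at most one dimension is unknown; `resolve_shape` is a hypothetical helper, not part of red-chainer or Numo.

```ruby
# Hypothetical sketch of how a single unknown (-1) dimension is resolved.
def resolve_shape(size, shape)
  dims = shape.map { |s| s == -1 ? nil : s }  # same mapping as #forward
  unknown = dims.count(nil)
  raise ArgumentError, "at most one dimension may be -1" if unknown > 1

  known = dims.compact.reduce(1, :*)
  if unknown.zero?
    raise ArgumentError, "cannot reshape #{size} elements into #{shape.inspect}" unless known == size
    return dims
  end
  unless known.positive? && (size % known).zero?
    raise ArgumentError, "cannot reshape #{size} elements into #{shape.inspect}"
  end

  # Fill the single unknown dimension with the remaining element count.
  dims.map { |d| d.nil? ? size / known : d }
end

p resolve_shape(24, [2, -1, 4])  # => [2, 3, 4]
p resolve_shape(6, [-1])         # => [6]
```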