Class: Chainer::Functions::Array::Squeeze
- Inherits: Chainer::FunctionNode
- Hierarchy: Object > Chainer::FunctionNode > Chainer::Functions::Array::Squeeze
- Defined in: lib/chainer/functions/array/squeeze.rb
Instance Attribute Summary
Attributes inherited from Chainer::FunctionNode
Class Method Summary
- .squeeze(x, axis: nil) ⇒ Chainer::Variable
  Remove dimensions of size one from the shape of a Numo::NArray.
Instance Method Summary
- #backward(indexes, grad_outputs) ⇒ Object
- #forward(inputs) ⇒ Object
- #initialize(axis: nil) ⇒ Squeeze (constructor)
  A new instance of Squeeze.
Methods inherited from Chainer::FunctionNode
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Constructor Details
#initialize(axis: nil) ⇒ Squeeze
Returns a new instance of Squeeze.
# File 'lib/chainer/functions/array/squeeze.rb', line 15

def initialize(axis: nil)
  if axis.nil?
    @axis = nil
  elsif axis.kind_of?(Integer)
    @axis = [axis]
  elsif axis.kind_of?(::Array) && Array(axis).all? { |i| i.kind_of?(Integer) }
    @axis = axis
  else
    raise TypeError, 'axis must be None, int or tuple of ints'
  end
end
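The constructor accepts `axis` as `nil`, a single Integer, or an Array of Integers, and stores either `nil` or an Array. As a minimal sketch, the standalone method below restates that normalization in plain Ruby so it can be tried without red-chainer; `normalize_squeeze_axis` is a hypothetical name, not part of the library, and its error message uses Ruby type names rather than the Python ones in the listing.

```ruby
# Hypothetical standalone version of the axis handling in Squeeze#initialize.
# Returns nil (meaning "squeeze every size-one dimension") or an Array of Integers.
def normalize_squeeze_axis(axis)
  case axis
  when nil
    nil
  when Integer
    [axis] # wrap a single axis so later code can always iterate over it
  when Array
    unless axis.all? { |i| i.is_a?(Integer) }
      raise TypeError, 'axis must be nil, an Integer or an Array of Integers'
    end
    axis
  else
    raise TypeError, 'axis must be nil, an Integer or an Array of Integers'
  end
end

p normalize_squeeze_axis(nil)     # nil
p normalize_squeeze_axis(2)       # [2]
p normalize_squeeze_axis([0, -1]) # [0, -1]
```

In the library itself the check is written as `axis.kind_of?(::Array)` with a leading `::`, because inside `Chainer::Functions::Array` a bare `Array` constant would resolve to that namespace module rather than Ruby's core Array class.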
Class Method Details
.squeeze(x, axis: nil) ⇒ Chainer::Variable
Remove dimensions of size one from the shape of a Numo::NArray.
# File 'lib/chainer/functions/array/squeeze.rb', line 11

def self.squeeze(x, axis: nil)
  self.new(axis: axis).apply([x]).first
end
Instance Method Details
#backward(indexes, grad_outputs) ⇒ Object
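# File 'lib/chainer/functions/array/squeeze.rb', line 47

def backward(indexes, grad_outputs)
  if @axis.nil?
    # With axis: nil, forward removed every size-one dimension of the input.
    axis = argone(@inputs[0].shape)
  else
    axis = @axis
    ndim = @inputs[0].shape.size
    # Normalize negative axes and sort them so insertions happen left to right.
    axis = axis.map { |x| x < 0 ? x + ndim : x }
    axis.sort!
  end
  gx = grad_outputs.first
  shape = gx.shape
  # Reinsert a size-one dimension at each squeezed position.
  axis.each do |x|
    shape.insert(x, 1)
  end
  [gx.reshape(*shape)]
end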
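The backward pass undoes the squeeze by reinserting a size-one dimension at each removed position. The positions must be non-negative and sorted ascending, because `Array#insert` shifts later elements and each insertion has to land at its original index. The pure-Ruby sketch below mirrors that logic on plain shape arrays. `argone` is not shown on this page, so the version here is an assumption (it returns the indices of size-one dimensions), and `unsqueeze_shape` is a hypothetical name.

```ruby
# Hypothetical helper, assumed to behave like the argone used by backward:
# returns the indices of all dimensions whose size is one.
def argone(shape)
  shape.each_index.select { |i| shape[i] == 1 }
end

# Hypothetical shape-only counterpart of Squeeze#backward:
# rebuilds the input shape from the gradient's (squeezed) shape.
def unsqueeze_shape(grad_shape, input_shape, axis)
  if axis.nil?
    positions = argone(input_shape)
  else
    ndim = input_shape.size
    # Normalize negative axes, then sort so insertions happen left to right.
    positions = axis.map { |a| a < 0 ? a + ndim : a }.sort
  end
  shape = grad_shape.dup
  positions.each { |pos| shape.insert(pos, 1) }
  shape
end

p unsqueeze_shape([3, 2], [1, 3, 1, 2], nil) # [1, 3, 1, 2]
p unsqueeze_shape([3, 2], [3, 1, 2], [-2])   # [3, 1, 2]
```

Passing the unsorted axes `[2, 0]` still restores `[1, 3, 1, 2]` because the sketch sorts them first; inserting at 2 and then at 0 without sorting would produce `[1, 3, 2, 1]`.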
#forward(inputs) ⇒ Object
# File 'lib/chainer/functions/array/squeeze.rb', line 27

def forward(inputs)
  x = inputs.first
  shape = x.shape

  # TODO: numpy.squeeze
  if @axis.nil?
    # Drop every dimension of size one.
    new_shape = shape.reject { |axis| axis == 1 }
  else
    new_shape = shape
    # Each requested axis must have size one; mark it with nil for removal.
    @axis.each do |a|
      raise StandardError, "cannot select an axis to squeeze out which has size not equal to one" unless shape[a] == 1
      new_shape[a] = nil
    end
    new_shape.compact!
  end
  # If every dimension was squeezed out, build a new array of the same class
  # filled with the single element instead of reshaping.
  ret = new_shape.size.zero? ? x.class.new.fill(x[0]) : x.reshape(*new_shape)
  [ret]
end
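The forward pass only changes the shape of the data. As a minimal sketch, the method below reproduces that shape computation on plain arrays: with no axis every size-one dimension is dropped, and otherwise each listed axis (negative values count from the end, as in Ruby Array indexing) must have size one. `squeeze_shape` is a hypothetical name, and it raises `ArgumentError` (a subclass of `StandardError`) where the listing raises `StandardError` directly.

```ruby
# Hypothetical shape-only counterpart of Squeeze#forward.
def squeeze_shape(shape, axis = nil)
  # No axis given: drop every dimension of size one.
  return shape.reject { |size| size == 1 } if axis.nil?

  new_shape = shape.dup
  axis.each do |a|
    unless shape[a] == 1
      raise ArgumentError,
            'cannot select an axis to squeeze out which has size not equal to one'
    end
    new_shape[a] = nil # mark the dimension for removal
  end
  new_shape.compact
end

p squeeze_shape([1, 3, 1, 2])       # [3, 2]
p squeeze_shape([1, 3, 1, 2], [-2]) # [1, 3, 2]
p squeeze_shape([1, 1])             # []
```

An empty result corresponds to the case the listing handles separately: when `new_shape` is empty, `forward` builds a new array of the same class filled with `x[0]` instead of calling `reshape`.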