Class: Chainer::Functions::Array::Squeeze

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/array/squeeze.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(axis: nil) ⇒ Squeeze

Returns a new instance of Squeeze.



# File 'lib/chainer/functions/array/squeeze.rb', line 15

def initialize(axis: nil)
  if axis.nil?
    @axis = nil     # squeeze every dimension of size one
  elsif axis.kind_of?(Integer)
    @axis = [axis]  # normalize a single axis to a one-element list
  elsif axis.kind_of?(::Array) && axis.all? { |i| i.kind_of?(Integer) }
    @axis = axis
  else
    raise TypeError, 'axis must be nil, an Integer, or an Array of Integers'
  end
end
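The constructor normalizes `axis` into either `nil` or an Array of Integers, so #forward and #backward only have to handle those two cases. A minimal standalone sketch of that normalization follows; `normalize_squeeze_axis` is a hypothetical helper name, not part of red-chainer.

```ruby
# Hypothetical helper (not part of red-chainer) mirroring Squeeze#initialize:
# normalizes `axis` to nil or an Array of Integers, rejecting anything else.
def normalize_squeeze_axis(axis)
  return nil if axis.nil?              # nil later means "squeeze every size-one dimension"
  return [axis] if axis.is_a?(Integer) # a single axis becomes a one-element list
  return axis if axis.is_a?(Array) && axis.all? { |i| i.is_a?(Integer) }

  raise TypeError, 'axis must be nil, an Integer, or an Array of Integers'
end

p normalize_squeeze_axis(nil)     # prints nil
p normalize_squeeze_axis(1)       # prints [1]
p normalize_squeeze_axis([0, -1]) # prints [0, -1]
```

Inside `Chainer::Functions::Array`, the bare constant `Array` resolves to the enclosing module, which is why the constructor tests `::Array`; this top-level sketch can use `Array` directly.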

Class Method Details

.squeeze(x, axis: nil) ⇒ Chainer::Variable

Removes dimensions of size one from the shape of a Numo::NArray or Chainer::Variable.

Parameters:

  • x (Chainer::Variable or Numo::NArray)

    Input data.

  • axis (nil, Integer, or Array<Integer>) (defaults to: nil)

    The axes of size one to remove. If `nil` is supplied, all dimensions of size one are removed. Axis indices start at zero, and negative indices count back from the last dimension. If a selected axis has a size other than one, an error is raised.

Returns:

  • (Chainer::Variable)

    Variable whose data has the selected dimensions of size one removed.



# File 'lib/chainer/functions/array/squeeze.rb', line 11

def self.squeeze(x, axis: nil)
  self.new(axis: axis).apply([x]).first
end
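The rules documented for `axis` can be checked at the level of shapes alone. The sketch below is a pure-Ruby helper, `squeezed_shape` (a hypothetical name, not part of red-chainer), that computes the output shape without Numo. It is an illustration under those assumptions, not the library's implementation.

```ruby
# Hypothetical helper (not part of red-chainer): computes the shape that
# squeeze produces for `shape` and `axis`, without touching any array data.
def squeezed_shape(shape, axis = nil)
  # With no axis, every dimension of size one is dropped.
  return shape.reject { |size| size == 1 } if axis.nil?

  ndim = shape.size
  # Accept a single Integer or an Array; resolve negative axes from the end.
  axes = Array(axis).map { |a| a < 0 ? a + ndim : a }
  axes.each do |a|
    raise ArgumentError, "axis #{a} is out of range for rank #{ndim}" unless a.between?(0, ndim - 1)
    raise ArgumentError, "cannot squeeze axis #{a}: its size is #{shape[a]}, not 1" unless shape[a] == 1
  end
  shape.reject.with_index { |_size, i| axes.include?(i) }
end

p squeezed_shape([1, 3, 1])       # prints [3]
p squeezed_shape([1, 3, 1], 0)    # prints [3, 1]
p squeezed_shape([1, 3, 1], [-1]) # prints [1, 3]
p squeezed_shape([1, 1])          # prints []
```

An empty result is the case #forward handles separately by wrapping the single remaining element in a new array (`x.class.new.fill(x[0])`). This sketch raises ArgumentError where the source raises StandardError, and it checks ranges explicitly instead of relying on Ruby's negative Array indexing.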

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object



# File 'lib/chainer/functions/array/squeeze.rb', line 47

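def backward(indexes, grad_outputs)
  if @axis.nil?
    # Every dimension of size one in the input was removed.
    axis = argone(@inputs[0].shape)
  else
    # Resolve negative axes against the input rank, then sort ascending so
    # that each insertion below lands at its final position.
    ndim = @inputs[0].shape.size
    axis = @axis.map { |x| x < 0 ? x + ndim : x }
    axis.sort!
  end
  gx = grad_outputs.first

  # Reinsert a dimension of size one at each squeezed axis.
  shape = gx.shape.dup
  axis.each do |x|
    shape.insert(x, 1)
  end
  [gx.reshape(*shape)]
end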
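#backward undoes the squeeze by reinserting a size-one dimension at each removed axis. The pure-Ruby sketch below reproduces that shape logic so it can be checked without Numo. The body of `argone` is not shown in this documentation; the version here is a stand-in that assumes it returns the indices of the size-one dimensions in ascending order, and `unsqueezed_shape` is a hypothetical name.

```ruby
# Stand-in for the private `argone` helper (its real body is not shown here):
# assumed to return the indices of size-one dimensions in ascending order.
def argone(shape)
  shape.each_index.select { |i| shape[i] == 1 }
end

# Hypothetical helper mirroring #backward: rebuilds the input shape from the
# gradient's shape. `axis` is nil or an Array, as normalized by the constructor.
def unsqueezed_shape(grad_shape, input_shape, axis = nil)
  ndim = input_shape.size
  axes =
    if axis.nil?
      argone(input_shape)
    else
      axis.map { |a| a < 0 ? a + ndim : a }.sort
    end
  shape = grad_shape.dup # leave the caller's array untouched
  # Ascending order matters: each insert assumes the earlier axes are already back.
  axes.each { |a| shape.insert(a, 1) }
  shape
end

p unsqueezed_shape([3], [1, 3, 1])          # prints [1, 3, 1]
p unsqueezed_shape([1, 3], [1, 3, 1], [-1]) # prints [1, 3, 1]
```

The sort is what makes repeated `insert` calls safe: without it, inserting at index 2 into `[3]` first would pad the array with `nil` and produce a wrong shape.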

#forward(inputs) ⇒ Object



# File 'lib/chainer/functions/array/squeeze.rb', line 27

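def forward(inputs)
  x = inputs.first
  shape = x.shape

  # TODO: numpy.squeeze
  if @axis.nil?
    # Drop every dimension of size one.
    new_shape = shape.reject { |size| size == 1 }
  else
    # new_shape aliases shape, so a repeated axis fails the size check below.
    new_shape = shape
    # Mark each requested axis for removal (negative axes index from the end).
    @axis.each do |a|
      raise StandardError, "cannot select an axis to squeeze out which has size not equal to one" unless shape[a] == 1
      new_shape[a] = nil
    end
    new_shape.compact!
  end
  # Removing every dimension leaves a single element, wrapped in a new array of the same class.
  ret = new_shape.size.zero? ? x.class.new.fill(x[0]) : x.reshape(*new_shape)

  [ret]
end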