Class: Chainer::Functions::Math::Sum

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/math/sum.rb

Overview

Sum of array elements over a given axis.

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

.sum(x, axis: nil, keepdims: false) ⇒ Chainer::Variable

    Sum of array elements over a given axis.

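Instance Method Summary

#initialize(axis: nil, keepdims: false) ⇒ Sum

    Returns a new instance of Sum.

#backward(indexes, grad_outputs) ⇒ Object

#forward(inputs) ⇒ Object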

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(axis: nil, keepdims: false) ⇒ Sum

Returns a new instance of Sum.

# File 'lib/chainer/functions/math/sum.rb', line 16

def initialize(axis: nil, keepdims: false)
  if axis.nil?
    @axis = nil
  elsif axis.is_a?(Integer)
    @axis = [axis]
  elsif axis.is_a?(::Array) && axis.all? { |e| e.is_a?(Integer) }
    raise ArgumentError, "duplicate value in axis: #{axis}" unless axis.uniq.size == axis.size
    @axis = axis
  else
    raise TypeError, 'nil, Integer or Array of int are required'
  end

  @keepdims = keepdims
end
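
The constructor normalizes axis into either nil or an Array of Integers and rejects everything else. The plain-Ruby sketch below restates those checks as a standalone function so the accepted inputs are explicit. The helper name normalize_axis is hypothetical and is not part of red-chainer; it uses the top-level Array class, whereas the library writes ::Array to avoid the Chainer::Functions::Array namespace.

```ruby
# Hypothetical standalone helper mirroring the axis checks in Sum#initialize.
# Returns nil or an Array of Integer axes; raises on invalid input.
def normalize_axis(axis)
  case axis
  when nil
    nil                      # nil means "sum over every axis"
  when Integer
    [axis]                   # a single axis is wrapped in an Array
  when Array
    unless axis.all? { |e| e.is_a?(Integer) }
      raise TypeError, 'nil, Integer or Array of int are required'
    end
    # Duplicates are checked on the raw values, before negative axes are resolved.
    raise ArgumentError, "duplicate value in axis: #{axis}" unless axis.uniq.size == axis.size
    axis
  else
    raise TypeError, 'nil, Integer or Array of int are required'
  end
end

p normalize_axis(nil)      # prints nil
p normalize_axis(1)        # prints [1]
p normalize_axis([0, -1])  # prints [0, -1]
```

Because the duplicate check runs on the raw values, an argument such as [1, -1] passes the constructor even though both entries name the same axis of a 2-D input.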

Class Method Details

.sum(x, axis: nil, keepdims: false) ⇒ Chainer::Variable

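Sum of array elements over a given axis.

Parameters:

  • x (Chainer::Variable)

    Elements to sum.

  • axis (nil, Integer, Array<Integer>) (defaults to: nil)

    Axis or axes along which the sum is performed. Negative values count from the last axis. If nil, all elements are summed.

  • keepdims (Boolean) (defaults to: false)

    If true, the summed axes are retained in the result as axes of length one.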

Returns:

  • (Chainer::Variable)

    The summed result.

# File 'lib/chainer/functions/math/sum.rb', line 12

def self.sum(x, axis: nil, keepdims: false)
  Sum.new(axis: axis, keepdims: keepdims).apply([x]).first
end
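
To make the axis and keepdims semantics concrete, here is a plain-Ruby sketch for a 2-D nested Array. The helper sum2d is hypothetical and only illustrates the expected values and shapes; the real Sum#forward delegates to the input array's own sum method and casts the result back to the input's array class, whereas this sketch returns a bare Integer for a full reduction without keepdims.

```ruby
# Hypothetical plain-Ruby reference for a 2-D nested Array (not the library's code path).
def sum2d(x, axis: nil, keepdims: false)
  rows = x.size
  cols = x.first.size
  # Resolve negative axes against rank 2; nil means both axes.
  axes = axis.nil? ? [0, 1] : Array(axis).map { |a| a.negative? ? a + 2 : a }

  case axes.sort
  when [0, 1]
    total = x.flatten.sum
    keepdims ? [[total]] : total
  when [0]
    # Sum down each column; keepdims keeps a leading axis of length one.
    col_sums = (0...cols).map { |j| (0...rows).sum { |i| x[i][j] } }
    keepdims ? [col_sums] : col_sums
  when [1]
    # Sum across each row; keepdims keeps a trailing axis of length one.
    row_sums = x.map(&:sum)
    keepdims ? row_sums.map { |s| [s] } : row_sums
  else
    raise ArgumentError, "invalid axis for a 2-D array: #{axis.inspect}"
  end
end

x = [[1, 2, 3],
     [4, 5, 6]]
p sum2d(x)                               # prints 21
p sum2d(x, axis: 0)                      # prints [5, 7, 9]
p sum2d(x, axis: 1)                      # prints [6, 15]
p sum2d(x, axis: 1, keepdims: true)      # prints [[6], [15]]
p sum2d(x, axis: [0, 1], keepdims: true) # prints [[21]]
```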

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object

# File 'lib/chainer/functions/math/sum.rb', line 38

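def backward(indexes, grad_outputs)
  gy = grad_outputs.first
  ndim = @inputs.first.shape.size
  # Without keepdims the summed axes were dropped in forward, so put them
  # back as length-1 axes before broadcasting gy to the input shape.
  unless ndim == 0 || @axis.nil? || @keepdims
    # Resolve negative axes against the input rank.
    actual_axis = @axis.map { |axis| axis >= 0 ? axis : axis + ndim }
    shape = gy.shape
    # Insert in ascending order so earlier inserts do not shift later positions.
    actual_axis.sort.each { |axis| shape.insert(axis, 1) }
    gy = Chainer::Functions::Array::Reshape.reshape(gy, shape)
  end
  # The derivative of a sum is 1 for every element, so the gradient is gy broadcast to x's shape.
  [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, @inputs.first.shape)]
end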
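
Because the derivative of a sum with respect to each summed element is 1, the input gradient is simply gy copied along every summed axis, which is what broadcast_to provides. The reshape step matters when keepdims is false: the summed axes were dropped in forward, so gy may not line up with the input shape. The sketch below reproduces only the shape bookkeeping with two hypothetical helpers, reshaped_grad_shape and broadcastable?. The latter assumes NumPy-style rules (align trailing dimensions; each size must match or be 1), which is an assumption about BroadcastTo rather than its documented contract.

```ruby
# Hypothetical helper: the shape gy is reshaped to in Sum#backward.
def reshaped_grad_shape(input_shape, gy_shape, axis, keepdims)
  ndim = input_shape.size
  # Same early exit as Sum#backward: gy already broadcasts as-is.
  return gy_shape if ndim.zero? || axis.nil? || keepdims

  actual_axis = axis.map { |a| a >= 0 ? a : a + ndim }
  shape = gy_shape.dup
  # Re-insert a length-1 axis at each summed position, in ascending order.
  actual_axis.sort.each { |a| shape.insert(a, 1) }
  shape
end

# Hypothetical NumPy-style check: can shape `from` broadcast to shape `to`?
def broadcastable?(from, to)
  return false if from.size > to.size
  from.reverse.zip(to.reverse).all? { |f, t| f == t || f == 1 }
end

# Input of shape [2, 3, 4] summed over axes [0, 2]: forward yields shape [3].
p reshaped_grad_shape([2, 3, 4], [3], [0, 2], false) # prints [1, 3, 1]
p broadcastable?([1, 3, 1], [2, 3, 4])               # prints true
p broadcastable?([3], [2, 3, 4])                     # prints false
```

Reinserting the two length-1 axes turns [3] into [1, 3, 1], which broadcasts back to [2, 3, 4]; the unreshaped [3] would be aligned against the trailing axis of length 4 and fail.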

#forward(inputs) ⇒ Object

# File 'lib/chainer/functions/math/sum.rb', line 31

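def forward(inputs)
  x = inputs.first
  ret = x.sum(axis: @axis, keepdims: @keepdims)
  # Cast back to the input's array class so the output is always an array,
  # even when the reduction returns a plain Ruby number.
  ret = x.class.cast(ret)
  [ret]
end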