Class: Chainer::Functions::Math::Sum
- Inherits: Chainer::FunctionNode
  - Object
  - Chainer::FunctionNode
  - Chainer::Functions::Math::Sum
- Defined in: lib/chainer/functions/math/sum.rb
Overview
Sum of array elements over a given axis.
Instance Attribute Summary
Attributes inherited from Chainer::FunctionNode
Class Method Summary
- .sum(x, axis: nil, keepdims: false) ⇒ Chainer::Variable
  Sum of array elements over a given axis.
Instance Method Summary
- #backward(indexes, grad_outputs) ⇒ Object
- #forward(inputs) ⇒ Object
- #initialize(axis: nil, keepdims: false) ⇒ Sum (constructor)
  A new instance of Sum.
Methods inherited from Chainer::FunctionNode
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Constructor Details
#initialize(axis: nil, keepdims: false) ⇒ Sum
Returns a new instance of Sum.
    # File 'lib/chainer/functions/math/sum.rb', line 16
    def initialize(axis: nil, keepdims: false)
      if axis.nil?
        @axis = nil
      elsif axis.is_a?(Integer)
        @axis = [axis]
      elsif axis.is_a?(::Array) && axis.all? { |e| e.is_a?(Integer) }
        raise ArgumentError, "duplicate value in axis: #{axis}" unless axis.uniq.size == axis.size
        @axis = axis
      else
        raise TypeError, 'nil, Integer or Array of int are required'
      end
      @keepdims = keepdims
    end
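The constructor's axis handling can be tried without the library. The sketch below is a standalone pure-Ruby mirror of it. `normalize_axis` is a hypothetical helper name introduced here for illustration, not part of Chainer.

```ruby
# Hypothetical helper (not part of Chainer) mirroring the axis handling
# in Sum#initialize: nil stays nil, an Integer is wrapped in an Array,
# and an Array of Integers is accepted only if it has no repeated values.
def normalize_axis(axis)
  case axis
  when nil
    nil
  when Integer
    [axis]
  when Array
    unless axis.all? { |e| e.is_a?(Integer) }
      raise TypeError, 'nil, Integer or Array of int are required'
    end
    unless axis.uniq.size == axis.size
      raise ArgumentError, "duplicate value in axis: #{axis}"
    end
    axis
  else
    raise TypeError, 'nil, Integer or Array of int are required'
  end
end

p normalize_axis(nil)      # => nil
p normalize_axis(1)        # => [1]
p normalize_axis([0, -1])  # => [0, -1]
```

The duplicate check compares the values as given, so `[1, -1]` passes even when both entries would name the same axis of a 2-D array.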
Class Method Details
.sum(x, axis: nil, keepdims: false) ⇒ Chainer::Variable
Sum of array elements over a given axis.
@param keepdims If `true`, the summed axes are kept in the result as axes of length one.
    # File 'lib/chainer/functions/math/sum.rb', line 12
    def self.sum(x, axis: nil, keepdims: false)
      Sum.new(axis: axis, keepdims: keepdims).apply([x]).first
    end
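To illustrate what `axis` and `keepdims` do to the shape of the result, here is a minimal pure-Ruby sketch. `summed_shape` is a hypothetical helper, not a Chainer method, and it assumes the usual NumPy-style convention: summed axes are dropped by default and kept as length-one axes when `keepdims` is true.

```ruby
# Hypothetical helper (not part of Chainer): computes the shape a sum
# would produce, assuming NumPy-style axis/keepdims semantics.
def summed_shape(shape, axis: nil, keepdims: false)
  ndim = shape.size
  # axis: nil means "sum over every axis"; negative axes count from the end.
  axes = axis.nil? ? (0...ndim).to_a : Array(axis).map { |a| a >= 0 ? a : a + ndim }
  if keepdims
    shape.each_with_index.map { |len, i| axes.include?(i) ? 1 : len }
  else
    shape.each_with_index.reject { |_, i| axes.include?(i) }.map(&:first)
  end
end

p summed_shape([2, 3, 4], axis: 1)                  # => [2, 4]
p summed_shape([2, 3, 4], axis: 1, keepdims: true)  # => [2, 1, 4]
p summed_shape([2, 3, 4], axis: [0, -1])            # => [3]
p summed_shape([2, 3, 4])                           # => []
```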
Instance Method Details
#backward(indexes, grad_outputs) ⇒ Object
    # File 'lib/chainer/functions/math/sum.rb', line 38
    def backward(indexes, grad_outputs)
      gy = grad_outputs.first
      ndim = @inputs.first.shape.size
      # If the summed axes were dropped in forward, reinsert them as length-one axes.
      unless ndim == 0 || @axis.nil? || @keepdims
        # Map negative axes to their non-negative equivalents.
        actual_axis = @axis.map { |axis| axis >= 0 ? axis : axis + ndim }
        shape = gy.shape
        actual_axis.sort.each { |axis| shape.insert(axis, 1) }
        gy = Chainer::Functions::Array::Reshape.reshape(gy, shape)
      end
      # Every input element receives the gradient of the sum it contributed to.
      [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, @inputs.first.shape)]
    end
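The shape arithmetic in `backward` can be followed on plain Ruby arrays. The sketch below reproduces the axis normalization and the insertion of length-one axes, then shows the broadcast for the simplest case. Both helper names are hypothetical. `broadcast_columns` covers only a 1-D gradient from a sum over axis 1 of a 2-D input, and it combines the reshape and broadcast steps into one.

```ruby
# Hypothetical helper (not part of Chainer): the shape the incoming
# gradient is reshaped to before broadcasting, following the same
# steps as Sum#backward.
def restored_grad_shape(gy_shape, axis, ndim)
  actual_axis = axis.map { |a| a >= 0 ? a : a + ndim }
  shape = gy_shape.dup
  # Inserting in ascending order keeps the later insert positions valid.
  actual_axis.sort.each { |a| shape.insert(a, 1) }
  shape
end

# Hypothetical helper: broadcasts a 1-D gradient (from a sum over axis 1
# of a 2-D input) back to `width` columns per row.
def broadcast_columns(gy, width)
  gy.map { |g| Array.new(width, g) }
end

p restored_grad_shape([2, 4], [1], 3)   # => [2, 1, 4]
p restored_grad_shape([3], [0, -1], 3)  # => [1, 3, 1]
p broadcast_columns([10, 20], 3)        # => [[10, 10, 10], [20, 20, 20]]
```

Each input element contributes to its sum with a coefficient of one, so the gradient with respect to the input is the output gradient copied along the summed axes.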
#forward(inputs) ⇒ Object
    # File 'lib/chainer/functions/math/sum.rb', line 31
    def forward(inputs)
      x = inputs.first
      ret = x.sum(axis: @axis, keepdims: @keepdims)
      # Cast back to the input's array class so the output has the same type as x.
      ret = x.class.cast(ret)
      [ret]
    end
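As a rough illustration of the values `forward` computes, the following sketch sums a 2-D nested Ruby Array along an axis. `sum_2d` is a hypothetical stand-in for the array library's `sum`, restricted to two dimensions. It does not model the cast back to the input's array class.

```ruby
# Hypothetical helper (not part of Chainer): axis-wise sum of a 2-D
# nested Array, assuming NumPy-style axis/keepdims semantics.
def sum_2d(rows, axis: nil, keepdims: false)
  case axis
  when nil
    total = rows.flatten.sum
    keepdims ? [[total]] : total
  when 0, -2
    # Sum down each column.
    col = rows.transpose.map(&:sum)
    keepdims ? [col] : col
  when 1, -1
    # Sum across each row.
    row = rows.map(&:sum)
    keepdims ? row.map { |v| [v] } : row
  else
    raise ArgumentError, "axis out of range for a 2-D array: #{axis}"
  end
end

x = [[1, 2, 3], [4, 5, 6]]
p sum_2d(x)                           # => 21
p sum_2d(x, axis: 0)                  # => [5, 7, 9]
p sum_2d(x, axis: 1)                  # => [6, 15]
p sum_2d(x, axis: 1, keepdims: true)  # => [[6], [15]]
```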