Class: Chainer::Functions::Array::BroadcastTo
- Inherits: Chainer::FunctionNode
  - Object
    - Chainer::FunctionNode
      - Chainer::Functions::Array::BroadcastTo
- Defined in: lib/chainer/functions/array/broadcast_to.rb
Overview
Function that broadcasts an array to a new shape.
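The overview does not say which shapes are compatible. The axis bookkeeping in `#backward` suggests NumPy-style rules: the input's dimensions line up with the trailing dimensions of the target shape, any extra leading dimensions are new, and only size-1 input dimensions are expanded. Assuming those rules, a plain-Ruby compatibility check might look like the sketch below. The method name `broadcastable?` is hypothetical and is not part of red-chainer.

```ruby
# Hypothetical helper, not part of red-chainer: checks whether an array of
# shape `from` can be broadcast to shape `to`, assuming NumPy-style rules
# (trailing dimensions aligned; each input dimension must be 1 or equal).
def broadcastable?(from, to)
  # Broadcasting can add dimensions but never remove them.
  return false if from.size > to.size

  # Compare each input dimension with the matching trailing target dimension.
  lead = to.size - from.size
  from.each_with_index.all? do |dim, i|
    dim == 1 || dim == to[lead + i]
  end
end

puts broadcastable?([3], [2, 3])     # true: a new leading axis of size 2
puts broadcastable?([1, 3], [4, 3])  # true: the size-1 axis is expanded
puts broadcastable?([2, 3], [4, 3])  # false: 2 is neither 1 nor 4
```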
Instance Attribute Summary
Attributes inherited from Chainer::FunctionNode
Class Method Summary
- .broadcast_to(x, shape) ⇒ Object
Instance Method Summary
- #backward(indexes, grad_outputs) ⇒ Object
- #forward(inputs) ⇒ Object
- #initialize(shape) ⇒ BroadcastTo (constructor): Returns a new instance of BroadcastTo.
Methods inherited from Chainer::FunctionNode
#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain
Constructor Details
#initialize(shape) ⇒ BroadcastTo
Returns a new instance of BroadcastTo that broadcasts its input to `shape`.
```ruby
# File 'lib/chainer/functions/array/broadcast_to.rb', line 6

def initialize(shape)
  # Target shape, used later by #forward.
  @shape = shape
end
```
Class Method Details
.broadcast_to(x, shape) ⇒ Object
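```ruby
# File 'lib/chainer/functions/array/broadcast_to.rb', line 10

def self.broadcast_to(x, shape)
  # Shape already matches: wrap x as a Variable without adding a node.
  return Chainer::Variable.as_variable(x) if x.shape == shape
  # Otherwise apply a BroadcastTo node and return its single output.
  self.new(shape).apply([x]).first
end
```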
Instance Method Details
#backward(indexes, grad_outputs) ⇒ Object
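```ruby
# File 'lib/chainer/functions/array/broadcast_to.rb', line 20

def backward(indexes, grad_outputs)
  gx = grad_outputs.first
  shape = @inputs.first.shape  # shape of the original (pre-broadcast) input
  ndim = shape.size
  # Number of leading axes that broadcasting prepended.
  lead = gx.ndim - ndim
  lead_axis = lead.times.to_a
  # Input axes of size 1 were expanded; offset them past the leading axes.
  axis = shape.each_with_object([]).with_index do |(sx, res), i|
    next unless sx == 1
    res << i + lead
  end
  # Sum the gradient over every broadcast axis, keeping them as size 1.
  gx = Chainer::Functions::Math::Sum.sum(gx, axis: lead_axis + axis, keepdims: true)
  # Drop the prepended leading axes so the gradient matches the input shape.
  return [Chainer::Functions::Array::Squeeze.squeeze(gx, axis: lead_axis)] if lead > 0
  [gx]
end
```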
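The gradient of a broadcast is the upstream gradient summed over every axis that broadcasting created or expanded: the `lead` new leading axes, plus each axis where the input had size 1 (shifted by `lead`). Summing with `keepdims: true` leaves those axes at size 1, and squeezing `lead_axis` then recovers the input's shape. The plain-Ruby sketch below reproduces that axis list and works one 2-D case by hand with nested arrays instead of Chainer variables. The helper names `reduction_axes` and `sum_axis0_keepdims` are hypothetical and not part of red-chainer.

```ruby
# Hypothetical helper mirroring the axis bookkeeping in #backward:
# returns [lead_axis, axis] for an input of shape `in_shape` whose
# gradient has `out_ndim` dimensions.
def reduction_axes(in_shape, out_ndim)
  lead = out_ndim - in_shape.size
  lead_axis = lead.times.to_a
  # Axes where the input had size 1, shifted past the new leading axes.
  axis = in_shape.each_index.select { |i| in_shape[i] == 1 }.map { |i| i + lead }
  [lead_axis, axis]
end

p reduction_axes([1, 3], 2)  # => [[], [0]]
p reduction_axes([3], 2)     # => [[0], []]
p reduction_axes([2, 1], 3)  # => [[0], [2]]

# Worked case: x of shape [1, 3] was broadcast to [2, 3], so the gradient
# is summed over axis 0 with keepdims, giving shape [1, 3] again.
def sum_axis0_keepdims(rows)
  [rows.transpose.map(&:sum)]
end

grad_out = [[1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0]]
p sum_axis0_keepdims(grad_out)  # => [[5.0, 7.0, 9.0]]
```

Each input element was copied into every row, so its gradient is the sum of the upstream gradients of all its copies.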
#forward(inputs) ⇒ Object
```ruby
# File 'lib/chainer/functions/array/broadcast_to.rb', line 15

def forward(inputs)
  x = inputs.first
  # Broadcast the raw input array to the target shape; outputs are returned as an array.
  [Chainer::Utils::Array.broadcast_to(x, @shape)]
end
```
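`#forward` delegates to `Chainer::Utils::Array.broadcast_to`, whose internals are not shown on this page. As an illustration only, the sketch below broadcasts a nested Ruby array under the NumPy-style rules implied by `#backward`. The helper names `shape_of`, `broadcast_nested` and `expand_axes` are hypothetical. The real helper operates on the raw array data rather than nested Ruby arrays and may be implemented differently.

```ruby
# Hypothetical helpers, not part of red-chainer.

# Shape of a rectangular nested array, e.g. [[1, 2, 3]] -> [1, 3].
def shape_of(a)
  a.is_a?(Array) ? [a.size] + shape_of(a.first) : []
end

# Broadcast nested array (or scalar) `a` to `shape`.
def broadcast_nested(a, shape)
  in_shape = shape_of(a)
  if in_shape.size > shape.size
    raise ArgumentError, "cannot broadcast #{in_shape} to #{shape}"
  end

  # Prepend missing leading axes by wrapping `a` in length-1 arrays.
  a = [a] while shape_of(a).size < shape.size
  expand_axes(a, shape)
end

# Recursively expand size-1 axes to the target length.
def expand_axes(a, shape)
  return a if shape.empty?

  target = shape.first
  items =
    if a.size == target
      a
    elsif a.size == 1
      Array.new(target) { a.first }
    else
      raise ArgumentError, "axis of size #{a.size} cannot become #{target}"
    end
  items.map { |item| expand_axes(item, shape.drop(1)) }
end

p broadcast_nested([1, 2, 3], [2, 3])   # => [[1, 2, 3], [1, 2, 3]]
p broadcast_nested([[1], [2]], [2, 3])  # => [[1, 1, 1], [2, 2, 2]]
```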