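Class: Chainer::Functions::Array::BroadcastTo
Overview
Differentiable function that broadcasts an array to a new shape. The forward pass delegates to `Chainer::Utils::Array.broadcast_to`; the backward pass sums the incoming gradient over the broadcast axes so that it matches the input's shape.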
Instance Attribute Summary
#inputs, #outputs, #rank
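Class Method Summary
.broadcast_to(x, shape) ⇒ Object
Instance Method Summary
#initialize(shape), #forward(inputs) ⇒ Object, #backward(indexes, grad_outputs) ⇒ Object
Other instance methods listed on this page (defined elsewhere): #apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain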
Constructor Details
#initialize(shape) ⇒ BroadcastTo
Stores the target shape used by #forward.
# File 'lib/chainer/functions/array/broadcast_to.rb', line 6
# Creates a broadcast node for a fixed target shape.
#
# @param shape [Object] the shape the input is broadcast to in #forward.
#   `.broadcast_to` compares it against `x.shape`, so it is presumably an
#   Array of Integers (NOTE(review): confirm against callers).
def initialize(shape)
  @shape = shape
end
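Class Method Details
.broadcast_to(x, shape) ⇒ Object
Broadcasts x to shape. If x already has that shape, no function node is created and x is returned wrapped by Chainer::Variable.as_variable.
# File 'lib/chainer/functions/array/broadcast_to.rb', line 10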
# Broadcasts +x+ to +shape+.
#
# When +x+ already has the requested shape, this is a no-op apart from
# wrapping +x+ via Chainer::Variable.as_variable: no function node is built.
#
# @param x [Object] input array or variable; must respond to #shape
# @param shape [Object] target shape, compared with +x.shape+ using ==
# @return [Object] the first output of #apply, or +x+ wrapped as a variable
def self.broadcast_to(x, shape)
  if x.shape == shape
    Chainer::Variable.as_variable(x)
  else
    new(shape).apply([x]).first
  end
end
Instance Method Details
#backward(indexes, grad_outputs) ⇒ Object
Sums the output gradient over the axes introduced or expanded by broadcasting, so that it matches the input's shape.
# File 'lib/chainer/functions/array/broadcast_to.rb', line 20
# Gradient of the broadcast with respect to its single input.
#
# The upstream gradient has the broadcast (output) rank. It is reduced by
# summing over
#   * the leading axes that broadcasting prepended, and
#   * every input axis whose original extent was 1 (shifted past those
#     leading axes).
# The sum keeps dimensions so ranks stay aligned. The prepended leading axes
# are then squeezed away, giving the input's rank again.
#
# @param indexes [Object] not used by this implementation
# @param grad_outputs [Array] one-element list holding the output gradient
# @return [Array] one-element list holding the input gradient
def backward(indexes, grad_outputs)
  grad = grad_outputs.first
  in_shape = @inputs.first.shape
  n_lead = grad.ndim - in_shape.size

  leading = (0...n_lead).to_a
  expanded = in_shape.each_with_index
                     .select { |extent, _pos| extent == 1 }
                     .map { |_extent, pos| pos + n_lead }

  summed = Chainer::Functions::Math::Sum.sum(grad, axis: leading + expanded, keepdims: true)
  return [summed] if n_lead <= 0

  # Leading axes now have extent 1 and do not exist in the input; drop them.
  [Chainer::Functions::Array::Squeeze.squeeze(summed, axis: leading)]
end
#forward(inputs) ⇒ Object
Broadcasts the single input to the shape given to the constructor.
# File 'lib/chainer/functions/array/broadcast_to.rb', line 15
# Broadcasts the single input to the shape given at construction.
#
# @param inputs [Array] one-element list holding the input array
# @return [Array] one-element list with the result of
#   Chainer::Utils::Array.broadcast_to
def forward(inputs)
  source = inputs[0]
  broadcast = Chainer::Utils::Array.broadcast_to(source, @shape)
  [broadcast]
end
|