Class: Chainer::Functions::Array::BroadcastTo

Inherits:
Chainer::FunctionNode
Defined in:
lib/chainer/functions/array/broadcast_to.rb

Overview

Function that broadcasts an array to a new shape.

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

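Class Method Summary

.broadcast_to(x, shape) ⇒ Object — returns x as a Variable when it already has the target shape; otherwise applies a new BroadcastTo node to it.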

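Instance Method Summary

#initialize(shape) ⇒ BroadcastTo — stores the target shape.
#forward(inputs) ⇒ Object — broadcasts the input to the target shape.
#backward(indexes, grad_outputs) ⇒ Object — sums the output gradient back down to the input's shape.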

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(shape) ⇒ BroadcastTo

Returns a new instance of BroadcastTo.



# File 'lib/chainer/functions/array/broadcast_to.rb', lines 6-8

# Creates a broadcast node targeting +shape+.
#
# The shape is only recorded here; the broadcast itself happens in #forward.
#
# @param shape target shape the input will be broadcast to
def initialize(shape)
  @shape = shape
end

Class Method Details

.broadcast_to(x, shape) ⇒ Object



# File 'lib/chainer/functions/array/broadcast_to.rb', lines 10-13

# Broadcasts +x+ to +shape+ as a differentiable operation.
#
# If +x+ already has the requested shape, no BroadcastTo node is created;
# +x+ is only wrapped with Chainer::Variable.as_variable.
#
# @param x input; its #shape is compared against +shape+ with ==
# @param shape target shape
# @return the first output of a new BroadcastTo node, or +x+ as a Variable
def self.broadcast_to(x, shape)
  if x.shape == shape
    Chainer::Variable.as_variable(x)
  else
    outputs = new(shape).apply([x])
    outputs.first
  end
end

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object



# File 'lib/chainer/functions/array/broadcast_to.rb', lines 20-33

# Backpropagates through the broadcast by summing the output gradient back
# down to the input's shape.
#
# Broadcasting can (a) prepend leading axes and (b) stretch size-1 axes of
# the input. Both are undone here. Every leading axis and every stretched
# axis is summed with keepdims, and then the leading axes are squeezed away.
#
# @param indexes indexes of the inputs whose gradients are requested
#   (not used by this implementation)
# @param grad_outputs one-element array holding the gradient w.r.t. the output
# @return [Array] one-element array holding the gradient w.r.t. the input
def backward(indexes, grad_outputs)
  gx = grad_outputs.first
  in_shape = @inputs.first.shape
  lead = gx.ndim - in_shape.size
  lead_axis = (0...lead).to_a
  # Input axes of size 1 were (possibly) stretched by the broadcast; shift by
  # `lead` to get their position in the gradient.
  stretched_axis = in_shape.each_index.select { |i| in_shape[i] == 1 }.map { |i| i + lead }
  sum_axis = lead_axis + stretched_axis

  # Nothing to reduce: the gradient already has the input's shape. Summing
  # over no axes is the identity, so skip Sum rather than depend on how it
  # interprets an empty axis list.
  return [gx] if sum_axis.empty?

  gx = Chainer::Functions::Math::Sum.sum(gx, axis: sum_axis, keepdims: true)
  gx = Chainer::Functions::Array::Squeeze.squeeze(gx, axis: lead_axis) if lead.positive?
  [gx]
end

#forward(inputs) ⇒ Object



# File 'lib/chainer/functions/array/broadcast_to.rb', lines 15-18

# Broadcasts the single input to the shape given at construction.
#
# The work is delegated to Chainer::Utils::Array.broadcast_to.
#
# @param inputs array whose first element is the input data
# @return [Array] one-element array holding the broadcast result
def forward(inputs)
  x, = inputs
  broadcasted = Chainer::Utils::Array.broadcast_to(x, @shape)
  [broadcasted]
end