Module: Chainer::Utils::Initializer

Defined in:
lib/chainer/utils/initializer.rb

Class Method Summary

Class Method Details

.get_fans(shape, device: Chainer::Device.default) ⇒ Array

# File 'lib/chainer/utils/initializer.rb', line 4

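def self.get_fans(shape, device: Chainer::Device.default)
  # Interpolate the offending shape into the message (the original '{}' was never filled in).
  raise "shape must be of length >= 2: shape=#{shape}" if shape.size < 2
  # Every dimension after the first two belongs to the receptive field (e.g. kernel height x width).
  slice_arr = shape.slice(2, shape.size)
  receptive_field_size = slice_arr.empty? ? 1 : device.xm::Int32[slice_arr].prod
  # shape[1] counts input units/channels, shape[0] counts output units/channels.
  fan_in = shape[1] * receptive_field_size
  fan_out = shape[0] * receptive_field_size
  [fan_in, fan_out]
end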
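get_fans treats shape[0] as the output dimension, shape[1] as the input dimension, and any remaining dimensions as the receptive field. This matches Chainer's weight layouts of (out_size, in_size) for a linear layer and (out_channels, in_channels, kh, kw) for a 2-D convolution. The sketch below is a hypothetical pure-Ruby re-implementation, named get_fans_sketch for illustration only. It replaces the Numo/Cumo product (device.xm::Int32[...].prod) with Array#reduce so it runs without red-chainer or Numo installed. It is not the library's API.

```ruby
# Hypothetical illustration of Chainer::Utils::Initializer.get_fans in plain Ruby;
# Array#reduce stands in for the device array product used by the real method.
def get_fans_sketch(shape)
  # Same guard as the library method: at least (out, in) is required.
  raise "shape must be of length >= 2: shape=#{shape}" if shape.size < 2

  # Product of all trailing dimensions; an empty list yields 1 (dense weights).
  receptive_field_size = shape.drop(2).reduce(1, :*)

  fan_in = shape[1] * receptive_field_size   # inputs feeding one output unit
  fan_out = shape[0] * receptive_field_size  # outputs fed by one input unit
  [fan_in, fan_out]
end

p get_fans_sketch([256, 128])     # => [128, 256]
p get_fans_sketch([64, 3, 3, 3])  # => [27, 576]
```

With red-chainer installed, Chainer::Utils::Initializer.get_fans([64, 3, 3, 3]) performs the same arithmetic for a convolution weight with 64 output channels, 3 input channels and a 3 x 3 kernel: fan_in = 3 * 3 * 3 = 27 and fan_out = 64 * 3 * 3 = 576.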