Module: Chainer::Initializers

Defined in:
lib/chainer/initializers/init.rb,
lib/chainer/initializers/normal.rb,
lib/chainer/initializers/uniform.rb,
lib/chainer/initializers/constant.rb

Defined Under Namespace

Classes: Constant, HeNormal, Normal, Uniform

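Class Method Summary

- .generate_array(initializer, shape, device: Chainer::Device.default) ⇒ Object — allocates an array of the given shape and passes it to the initializer.
- .get_initializer(initializer, device: Chainer::Device.default) ⇒ Object — converts nil, a numeric, an array, or a callable into an initializer.
- .nan(dtype: nil) ⇒ Object — returns a Constant initializer whose fill value is NaN.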

Class Method Details

.generate_array(initializer, shape, device: Chainer::Device.default) ⇒ Object



# File 'lib/chainer/initializers/init.rb', lines 3–10

# Allocates an array of +shape+ and hands it to +initializer+ to fill.
#
# The array class is +initializer.dtype+ when the initializer responds to
# +dtype+ and that value is truthy; otherwise it falls back to
# +device.xm::SFloat+. The new array is seeded via +rand+ before the
# initializer is called.
# NOTE(review): the rand seed is presumably overwritten by every initializer;
# confirm before dropping it.
#
# @param initializer [#call] invoked with the freshly allocated array
# @param shape [Object] forwarded unchanged to the array class constructor
# @param device [Object] supplies the array module through +xm+
# @return [Object] whatever +initializer.call+ returns
def self.generate_array(initializer, shape, device: Chainer::Device.default)
  fallback_class = device.xm::SFloat
  requested_class = initializer.dtype if initializer.respond_to?(:dtype)
  array_class = requested_class || fallback_class

  seeded = array_class.new(shape).rand
  initializer.call(seeded)
end

.get_initializer(initializer, device: Chainer::Device.default) ⇒ Object



# File 'lib/chainer/initializers/init.rb', lines 12–22

# Converts any accepted initializer specification into a callable initializer.
#
# * +nil+ yields a HeNormal initializer with scale 1/sqrt(2), using the
#   device's +xm::NMath.sqrt+.
# * A Numeric, or anything Chainer.array? accepts, is wrapped in a Constant.
# * Any other object that responds to +call+ is returned as-is.
#
# @param initializer [nil, Numeric, Object, #call] the specification to normalize
# @param device [Object] only consulted (via +xm+) when +initializer+ is nil
# @return [Object] an object usable as an initializer
# @raise [TypeError] when +initializer+ matches none of the accepted forms
def self.get_initializer(initializer, device: Chainer::Device.default)
  if initializer.nil?
    HeNormal.new(scale: 1 / device.xm::NMath.sqrt(2))
  elsif initializer.kind_of?(Numeric) || Chainer.array?(initializer)
    Constant.new(initializer)
  elsif initializer.respond_to?(:call)
    initializer
  else
    raise TypeError, "invalid type of initializer: #{initializer.class}"
  end
end

.nan(dtype: nil) ⇒ Object



# File 'lib/chainer/initializers/init.rb', lines 24–26

# Builds a Constant initializer that fills arrays with NaN.
#
# @param dtype [Object, nil] forwarded to Constant as its +dtype+ keyword
# @return [Constant] an initializer whose fill value is Float::NAN
def self.nan(dtype: nil)
  fill_value = Float::NAN
  Constant.new(fill_value, dtype: dtype)
end