Class: Chainer::Functions::Normalization::BatchNormalization

Inherits:
Chainer::FunctionNode
Includes:
Calculation
Defined in:
lib/chainer/functions/normalization/batch_normalization.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods included from Calculation

#apply_bn_fwd, #x_hat, #zero_if_none

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9) ⇒ BatchNormalization

Returns a new instance of BatchNormalization.



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 34

def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
  # Populated during forward and reused by backward.
  @mean = nil
  @inv_std = nil

  # Optional externally supplied running statistics, updated in place
  # on every forward pass.
  @running_mean = mean
  @running_var = var
  @eps = eps
  @decay = decay
end
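
Most callers go through the .batch_normalization class method below, but the node can be constructed directly when you want running statistics that persist across calls. A minimal sketch, assuming Numo::DFloat arrays; the shapes are illustrative, not from the source:

require 'chainer'

# Illustrative shapes: a (batch, channel, height, width) input with 3 channels.
x     = Chainer::Variable.new(Numo::DFloat.new(2, 3, 4, 4).rand)
gamma = Chainer::Variable.new(Numo::DFloat.ones(3))
beta  = Chainer::Variable.new(Numo::DFloat.zeros(3))

# Persistent running statistics, updated in place by each forward pass.
running_mean = Numo::DFloat.zeros(3)
running_var  = Numo::DFloat.zeros(3)

bn = Chainer::Functions::Normalization::BatchNormalization.new(
  eps: 2e-5, mean: running_mean, var: running_var, decay: 0.9)
y, = bn.apply([x, gamma, beta])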

Instance Attribute Details

#running_mean ⇒ Object (readonly)

Returns the value of attribute running_mean.



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 28

def running_mean
  @running_mean
end

#running_var ⇒ Object (readonly)

Returns the value of attribute running_var.



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 28

def running_var
  @running_var
end

Class Method Details

.batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9) ⇒ Object



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 30

def self.batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9)
  BatchNormalization.new(eps: eps, mean: running_mean, var: running_var, decay: decay).apply([x, gamma, beta])[0]
end
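
A hedged usage sketch of this functional entry point; the batch size and feature count are illustrative assumptions:

require 'chainer'

x     = Chainer::Variable.new(Numo::SFloat.new(8, 16).rand)  # batch of 8, 16 features
gamma = Chainer::Variable.new(Numo::SFloat.ones(16))
beta  = Chainer::Variable.new(Numo::SFloat.zeros(16))

running_mean = Numo::SFloat.zeros(16)
running_var  = Numo::SFloat.zeros(16)

y = Chainer::Functions::Normalization::BatchNormalization.batch_normalization(
  x, gamma, beta, running_mean: running_mean, running_var: running_var)
# running_mean and running_var now hold decay-weighted batch statistics.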

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 88

def backward(indexes, grad_outputs)
  x, gamma = get_retained_inputs
  gy, = grad_outputs

  # Delegate the gradient computation to BatchNormalizationGrad, reusing
  # the batch mean and inverse standard deviation cached during forward.
  f = BatchNormalizationGrad.new(@eps, @expander, @axis, @mean, @inv_std)
  f.(x, gamma, gy)
end
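
backward is normally reached through the autograd graph rather than called directly. A sketch of triggering it, reusing x, gamma, and beta from the example above and assuming the Variable#grad= and Variable#backward API that red-chainer mirrors from Chainer:

y = Chainer::Functions::Normalization::BatchNormalization.batch_normalization(x, gamma, beta)
y.grad = y.data.new_ones  # upstream gradient of ones, same shape as y
y.backward
# x.grad, gamma.grad, and beta.grad are now populated via BatchNormalizationGrad.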

#forward(inputs) ⇒ Object



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 44

def forward(inputs)
  retain_inputs([0, 1])
  x, gamma, beta = inputs
  xp = Chainer.get_array_module(x)

  # Lazily initialize the running statistics with the same type and
  # shape as gamma when none were supplied to the constructor.
  if @running_mean.nil?
    @running_mean = xp::NArray[*gamma].new_zeros
    @running_var = xp::NArray[*gamma].new_zeros
  end

  # expander inserts singleton dimensions into gamma and beta so that
  # they can be broadcast with x.
  head_ndim = gamma.ndim + 1
  # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
  suffix = [1] * (x.ndim - head_ndim)
  expander = -> (arr) do
    shape = [1] + arr.shape + suffix
    arr.reshape(*shape)
  end
  @expander = expander
  @axis = [0] + (head_ndim...(x.ndim)).to_a

  gamma = expander.(gamma)
  beta = expander.(beta)
  @mean = x.mean(axis: @axis)

  # TODO: Numo::NArray cannot be asked for the standard deviation
  # directly here, so derive the variance from the mean instead.
  var = ((x - x.mean(axis: @axis, keepdims: true)) ** 2).mean(axis: @axis)

  var += @eps
  @inv_std = var ** (-0.5)

  y = apply_bn_fwd(xp, x, expander.(@mean), expander.(@inv_std), gamma, beta)

  # Update the running statistics as an exponential moving average;
  # `adjust` converts the biased batch variance into an unbiased estimate.
  m = x.size.div(gamma.size)
  adjust = m / [m - 1.0, 1.0].max
  @running_mean *= @decay
  @running_mean += (1 - @decay) * @mean
  @running_var *= @decay
  @running_var += (1 - @decay) * adjust * var

  [y]
end
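
To see what the expander and @axis computation do in isolation, a plain-Numo sketch; the shapes are illustrative assumptions:

require 'numo/narray'

x     = Numo::DFloat.new(2, 3, 4, 4).rand  # (batch, channel, height, width)
gamma = Numo::DFloat.ones(3)

head_ndim = gamma.ndim + 1                 # => 2
suffix    = [1] * (x.ndim - head_ndim)     # => [1, 1]
expander  = ->(arr) { arr.reshape(*([1] + arr.shape + suffix)) }

expander.(gamma).shape                     # => [1, 3, 1, 1], broadcastable with x
axis = [0] + (head_ndim...x.ndim).to_a     # => [0, 2, 3]
x.mean(axis: axis).shape                   # => [3], one statistic per channel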