Class: Chainer::Functions::Normalization::FixedBatchNormalization

Inherits:
Chainer::FunctionNode
Includes:
Calculation
Defined in:
lib/chainer/functions/normalization/batch_normalization.rb

Instance Attribute Summary

Attributes inherited from Chainer::FunctionNode

#inputs, #outputs, #rank

Class Method Summary

Instance Method Summary

Methods included from Calculation

#apply_bn_fwd, #x_hat, #zero_if_none

Methods inherited from Chainer::FunctionNode

#apply, #backward_accumulate, #forward_cpu, #get_retained_inputs, #get_retained_outputs, #label, #output_data, #retain_inputs, #retain_outputs, #unchain

Constructor Details

#initialize(eps: 2e-5) ⇒ FixedBatchNormalization

Returns a new instance of FixedBatchNormalization.



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 176

def initialize(eps: 2e-5)
  @inv_std = nil
  @inv_var = nil
  @eps = eps
end
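
A minimal construction-and-apply sketch, not taken from the library's documentation: the shapes below are hypothetical, with a (batch, channel, height, width) array normalized by per-channel statistics. The node is applied through the inherited FunctionNode#apply, just as the .fixed_batch_normalization helper below does.

require 'chainer'

x     = Numo::SFloat.new(2, 3, 4, 4).rand   # hypothetical (batch, channel, h, w) input
gamma = Numo::SFloat.ones(3)                # per-channel scale
beta  = Numo::SFloat.zeros(3)               # per-channel shift
mean  = Numo::SFloat.zeros(3)               # fixed mean (e.g. a running mean)
var   = Numo::SFloat.ones(3)                # fixed variance (e.g. a running variance)

fbn = Chainer::Functions::Normalization::FixedBatchNormalization.new(eps: 1e-5)
y   = fbn.apply([x, gamma, beta, mean, var]).first   # apply returns an array of output Variables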

Instance Attribute Details

#inv_var ⇒ Object (readonly)

Returns the value of attribute inv_var, the element-wise reciprocal of var + eps cached by #forward.



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 170

def inv_var
  @inv_var
end

Class Method Details

.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5) ⇒ Object



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 172

def self.fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
  FixedBatchNormalization.new(eps: eps).apply([x, gamma, beta, mean, var]).first
end
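
This class method is a convenience wrapper that builds the node and applies it in one call. A hedged usage sketch with the same hypothetical shapes as above; the result is a Chainer::Variable whose data is the normalized array.

require 'chainer'

x     = Numo::SFloat.new(2, 3, 4, 4).rand
gamma = Numo::SFloat.ones(3)
beta  = Numo::SFloat.zeros(3)
mean  = Numo::SFloat.zeros(3)
var   = Numo::SFloat.ones(3)

y = Chainer::Functions::Normalization::FixedBatchNormalization
      .fixed_batch_normalization(x, gamma, beta, mean, var, eps: 2e-5)
y.data   # => normalized Numo array with the same shape as x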

Instance Method Details

#backward(indexes, grad_outputs) ⇒ Object



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 209

def backward(indexes, grad_outputs)
  # x, gamma, mean, var were retained in #forward (indices 0, 1, 3, 4); beta is not needed here.
  x, gamma, mean, var = get_retained_inputs
  gy, = grad_outputs
  f = FixedBatchNormalizationGrad.new(@eps, @expander, @axis, @inv_std, @inv_var)
  f.(x, gamma, mean, var, gy)
end
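
backward delegates the gradient computation to FixedBatchNormalizationGrad, constructed from the state cached by #forward. Below is a sketch of reaching it through the autograd API, under the assumption that Variable#backward and Variable#grad behave as in Chainer; shapes and values are illustrative.

require 'chainer'

x     = Chainer::Variable.new(Numo::SFloat.new(2, 3, 4, 4).rand)
gamma = Chainer::Variable.new(Numo::SFloat.ones(3))
beta  = Chainer::Variable.new(Numo::SFloat.zeros(3))

y = Chainer::Functions::Normalization::FixedBatchNormalization
      .fixed_batch_normalization(x, gamma, beta, Numo::SFloat.zeros(3), Numo::SFloat.ones(3))
y.grad = Numo::SFloat.ones(*y.shape)   # upstream gradient gy
y.backward                             # invokes #backward above through the graph
x.grad                                 # gradient with respect to x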

#forward(inputs) ⇒ Object



# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 182

def forward(inputs)
  retain_inputs([0, 1, 3, 4])
  x, gamma, beta, mean, var = inputs
  xp = Chainer.get_array_module(x)

  # expander inserts singleton dimensions to gamma and beta so that they
  # can be broadcasted with x.
  head_ndim = gamma.ndim + 1
  # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
  suffix = [1] * (x.ndim - head_ndim)
  expander = -> (arr) do
    shape = [1] + arr.shape + suffix
    arr.reshape(*shape)
  end
  @expander = expander
  @axis = [0] + (head_ndim...(x.ndim)).to_a

  gamma = expander.(gamma)
  beta = expander.(beta)
  var += @eps
  @inv_var = var.reciprocal
  @inv_std = xp::NMath.sqrt(@inv_var)

  y = apply_bn_fwd(xp, x, expander.(mean), expander.(@inv_std), gamma, beta)
  [y]
end
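
apply_bn_fwd and x_hat are provided by the included Calculation module and are not shown on this page. Assuming they implement the usual affine normalization gamma * (x - mean) * inv_std + beta, the forward pass above is equivalent to the plain-Numo sketch below for a 4-D input; the lambda mirrors the expander defined in #forward.

require 'numo/narray'

x     = Numo::SFloat.new(2, 3, 4, 4).rand
gamma = Numo::SFloat.ones(3)
beta  = Numo::SFloat.zeros(3)
mean  = Numo::SFloat.zeros(3)
var   = Numo::SFloat.ones(3)
eps   = 2e-5

# Reshape (3,) statistics to (1, 3, 1, 1) so they broadcast over the batch and spatial axes.
expand  = ->(a) { a.reshape(1, *a.shape, 1, 1) }
inv_std = Numo::NMath.sqrt((var + eps).reciprocal)   # 1 / sqrt(var + eps)

y_ref = expand.(gamma) * (x - expand.(mean)) * expand.(inv_std) + expand.(beta)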