Class: Chainer::Functions::Normalization::FixedBatchNormalizationGrad

Inherits:
Chainer::Function < Object
Includes:
Calculation
Defined in:
lib/chainer/functions/normalization/batch_normalization.rb

Instance Attribute Summary

Attributes inherited from Chainer::Function

#inputs, #output_data, #outputs, #owned_node, #rank, #retain_after_backward

Instance Method Summary

Methods included from Calculation

#apply_bn_fwd, #x_hat, #zero_if_none

Methods inherited from Chainer::Function

#backward_cpu, #backward_gpu, #call, #forward_cpu, #forward_gpu, #label, #node, #retain_inputs, #retain_outputs

Constructor Details

#initialize(eps, expander, axis, inv_std, inv_var) ⇒ FixedBatchNormalizationGrad

Returns a new instance of FixedBatchNormalizationGrad.



Source lines 220–226:
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 220

# Builds the gradient function for fixed-statistics batch normalization.
#
# @param eps [Float] added to +var+ before taking its reciprocal in #forward;
#   only consulted when +inv_std+ or +inv_var+ is nil.
# @param expander [#call] callable applied to per-channel arrays before they
#   are combined with +x+ / +gy+ (presumably a reshape for broadcasting —
#   confirm with the caller).
# @param axis axis/axes passed to +sum(axis:)+ in #forward and #backward.
# @param inv_std precomputed inverse standard deviation, or nil to have
#   #forward derive it from +var+ and +eps+.
# @param inv_var precomputed inverse variance, or nil to have #forward derive
#   it from +var+ and +eps+.
def initialize(eps, expander, axis, inv_std, inv_var)
  # Configuration shared by #forward and #backward.
  @eps, @expander, @axis = eps, expander, axis
  # Cached inverse statistics; if either is nil, #forward recomputes both.
  @inv_std, @inv_var = inv_std, inv_var
end

Instance Method Details

#backward(inputs, grad_outputs) ⇒ Object



Source lines 252–285:
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 252

# Second-order (double-backprop) step of fixed batch normalization.
#
# @param inputs [Array] [x, gamma, mean, var, gy]. Only x, gamma, mean and
#   gy are read; the var slot is ignored (#forward does not retain index 3).
# @param grad_outputs [Array] upstream gradients for #forward's outputs
#   [gx, ggamma, gbeta, gmean, gvar]. Each entry is routed through
#   +zero_if_none+ with the shape/class of x, gamma, gamma, mean and mean
#   respectively (presumably a nil entry becomes zeros of that shape).
# @return [Array] [gx2, ggamma2, gmean2, gvar2, ggy2], one entry per #forward
#   input, in the order x, gamma, mean, var, gy.
#
# Reads @inv_std, @inv_var and @gamma_over_std (the last one is set only by
# #forward) plus the retained ggamma/gbeta from +output_data+.
def backward(inputs, grad_outputs)
  x, gamma, mean, _, gy = inputs
  gg_x, gg_gamma, gg_beta, gg_mean, gg_var = grad_outputs
  _, ggamma_out, gbeta_out, _, _ = output_data

  # Handle nil in output gradients; each takes the shape/class of the array
  # it is paired with here. Calls happen in the same order as before.
  xp = Chainer.get_array_module(x)
  fill_pairs = [[gg_x, x], [gg_gamma, gamma], [gg_beta, gamma], [gg_mean, mean], [gg_var, mean]]
  gg_x, gg_gamma, gg_beta, gg_mean, gg_var =
    fill_pairs.map { |grad, like| zero_if_none(xp, grad, like.shape, like.class) }

  expand = @expander
  normalized = x_hat(x, expand.call(mean), expand.call(@inv_std))
  half_gg_var = -0.5 * gg_var

  gamma_over_var = gamma * @inv_var
  g_gamma_over_var = half_gg_var * ggamma_out

  gg_gamma2 = gg_gamma + half_gg_var * gamma_over_var
  g_normalized = gy * expand.call(gg_gamma2)
  gx2 = expand.call(@inv_std) * g_normalized
  gmean2 = -@inv_std * g_normalized.sum(axis: @axis)

  g_gamma_over_std = (gg_x * gy).sum(axis: @axis) - gg_mean * gbeta_out
  gg_beta2 = gg_beta - gg_mean * @gamma_over_std
  ggy2 = expand.call(gg_gamma2) * normalized +
         expand.call(gg_beta2) +
         expand.call(@gamma_over_std) * gg_x

  ggamma2 = @inv_var * g_gamma_over_var + @inv_std * g_gamma_over_std
  residual = (normalized * g_normalized).sum(axis: @axis) - @gamma_over_std * g_gamma_over_std
  gvar2 = -(ggamma2 * gamma_over_var + 0.5 * @inv_var * residual)

  [gx2, ggamma2, gmean2, gvar2, ggy2]
end

#forward(inputs) ⇒ Object



Source lines 228–250:
# File 'lib/chainer/functions/normalization/batch_normalization.rb', line 228

# First-order gradient computation for fixed batch normalization.
#
# @param inputs [Array] [x, gamma, mean, var, gy].
# @return [Array] [gx, ggamma, gbeta, gmean, gvar].
#
# Side effects:
# - retains inputs 0, 1, 2 and 4 and all five outputs (read back by #backward);
# - if @inv_std or @inv_var is nil, recomputes both: @inv_var as the
#   reciprocal of (var + @eps) and @inv_std as its square root;
# - caches @gamma_over_std (gamma * @inv_std) for #backward.
def forward(inputs)
  retain_inputs([0, 1, 2, 4])
  x, gamma, mean, var, gy = inputs
  expand = @expander
  xp = Chainer.get_array_module(x)

  if [@inv_std, @inv_var].any?(&:nil?)
    @inv_var = (var + @eps).reciprocal
    @inv_std = xp::NMath.sqrt(@inv_var)
  end

  @gamma_over_std = gamma * @inv_std
  normalized = x_hat(x, expand.call(mean), expand.call(@inv_std))

  gx = expand.call(@gamma_over_std) * gy
  gbeta = gy.sum(axis: @axis)
  ggamma = (normalized * gy).sum(axis: @axis)
  gmean = -@gamma_over_std * gbeta
  gvar = -0.5 * gamma * @inv_var * ggamma

  retain_outputs([0, 1, 2, 3, 4])
  [gx, ggamma, gbeta, gmean, gvar]
end