Class: Chainer::Links::Normalization::BatchNormalization
- Inherits:
-
Chainer::Link
- Object
- Chainer::Link
- Chainer::Links::Normalization::BatchNormalization
- Defined in:
- lib/chainer/links/normalization/batch_normalization.rb
Instance Attribute Summary
Attributes inherited from Chainer::Link
Instance Method Summary collapse
-
#call(x, finetune: false) ⇒ Object
Invokes the forward propagation of BatchNormalization.
-
#initialize(size, decay: 0.9, eps: 2e-5, dtype: nil, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil) ⇒ BatchNormalization
constructor
Batch normalization layer on outputs of linear or convolution functions.
-
#start_finetuning ⇒ Object
Resets the population count for collecting population statistics.
Methods inherited from Chainer::Link
#cleargrads, #del_attr, #init_scope, #namedlinks, #namedparams, #params, #register_persistent, #serialize, #set_attr, #within_init_scope
Constructor Details
#initialize(size, decay: 0.9, eps: 2e-5, dtype: nil, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil) ⇒ BatchNormalization
Batch normalization layer on outputs of linear or convolution functions.
It runs in three modes: training mode, fine-tuning mode, and testing mode. In training mode, it normalizes the input by *batch statistics*. It also maintains approximated population statistics by moving averages, which can be used for instant evaluation in testing mode.
In fine-tuning mode, it accumulates the input to compute *population statistics*. In order to correctly compute the population statistics, a user must use this mode to feed mini-batches running through whole training dataset.
In testing mode, it uses pre-computed population statistics to normalize the input variable. The population statistics is approximated if it is computed by training mode, or accurate if it is correctly computed by fine-tuning mode.
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# File 'lib/chainer/links/normalization/batch_normalization.rb', line 26 def initialize(size, decay: 0.9, eps: 2e-5, dtype: nil, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil) super() dtype ||= Chainer::Device.default.xm::SFloat @avg_mean = dtype.zeros(size) register_persistent('avg_mean') @avg_var = dtype.zeros(size) register_persistent('avg_var') @n = 0 register_persistent('n') @decay = decay @eps = eps init_scope do if use_gamma initial_gamma = 1 if initial_gamma.nil? initial_gamma = Chainer::Initializers.get_initializer(initial_gamma) initial_gamma.dtype = dtype @gamma = Chainer::Parameter.new(initializer: initial_gamma, shape: size) end if use_beta initial_beta = 0 if initial_beta.nil? initial_beta = Chainer::Initializers.get_initializer(initial_beta) initial_beta.dtype = dtype @beta = Chainer::Parameter.new(initializer: initial_beta, shape: size) end end end |
Instance Method Details
#call(x, finetune: false) ⇒ Object
Invokes the forward propagation of BatchNormalization. In training mode, the BatchNormalization computes moving averages of mean and variance for evaluatino during training, and normalizes the input using batch statistics. BatchNormalization runs in fine-tuning mode; it accumulates the input array to compute population statistics for normalization, and normalizes the input using batch statistics.
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
# File 'lib/chainer/links/normalization/batch_normalization.rb', line 62 def call(x, finetune: false) if self.instance_variable_defined?(:@gamma) gamma = @gamma else gamma = Chainer::Variable.new(x.data.class.ones(@avg_mean.shape)) end if self.instance_variable_defined?(:@beta) beta = @beta else beta = Chainer::Variable.new(x.data.class.zeros(*@avg_mean.shape)) end if Chainer.configuration.train if finetune @n += 1 decay = 1.0 - 1.0 / @n else decay = @decay end ret = Chainer::Functions::Normalization::BatchNormalization.batch_normalization(x, gamma, beta, eps: @eps, running_mean: @avg_mean, running_var: @avg_var, decay: decay) else mean = Chainer::Variable.new(@avg_mean) var = Chainer::Variable.new(@avg_var) ret = Chainer::Functions::Normalization::FixedBatchNormalization.fixed_batch_normalization(x, gamma, beta, mean, var, eps: @eps) end ret end |
#start_finetuning ⇒ Object
Resets the population count for collecting population statistics. This method can be skipped if it is the first time to use the fine-tuning mode. Otherwise, this method should be called before starting the fine-tuning mode again.
96 97 98 |
# File 'lib/chainer/links/normalization/batch_normalization.rb', line 96 def start_finetuning @n = 0 end |