Understanding Batch Normalization
In the previous post, we talked about Convolution and Pooling layers. Stacking a large number of these layers (convolutions with activation functions and pooling) gives a deep CNN architecture, which is often hard to train: once the network becomes very deep, it becomes very difficult to get it to converge. The most common solution to this problem is Batch Normalization.
The idea is to normalize the outputs of a layer so that they have zero mean and unit variance. Why? The distribution of the inputs to layers deep in the network may change after each mini-batch as the weights are updated, which can cause the learning algorithm to forever chase a moving target. This change in the distribution of inputs to layers in the network goes by the technical name "internal covariate shift", and batch normalization helps to reduce this internal covariate shift, improving optimization.
We introduce batch normalization as a layer in our network that takes in inputs and normalizes them. \begin{align*} & \hat{x}_{ij} = \frac{x_{ij} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} \\ & y_{ij} = \gamma_j \hat{x}_{ij} + \beta_j \end{align*} For each feature/channel $j$, $\mu_j$ and $\sigma_j$ are the mean and standard deviation computed across the mini-batch of inputs, and $\epsilon$ is a small constant for numerical stability. Forcing exactly zero mean and unit variance would be too hard a constraint on the network, so we add scale and shift parameters $\gamma$ and $\beta$ that are learned during training.
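As a concrete illustration of these two equations, here is a minimal training-time sketch in NumPy; the helper name `batchnorm_forward` and the 2D input shape (N, D) are assumptions of this sketch, not part of the original post.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm for a 2D input of shape (N, D).

    Each feature j is normalized with the mean and variance computed
    over the N examples in the mini-batch, then scaled and shifted
    by the learnable parameters gamma and beta.
    """
    mu = x.mean(axis=0)                      # per-feature mean, shape (D,)
    var = x.var(axis=0)                      # per-feature variance, shape (D,)
    x_hat = (x - mu) / np.sqrt(var + eps)    # zero mean, unit variance
    y = gamma * x_hat + beta                 # learned scale and shift
    return y, mu, var

# Example: a mini-batch of 4 examples with 3 features each
x = np.random.randn(4, 3)
gamma, beta = np.ones(3), np.zeros(3)
y, mu, var = batchnorm_forward(x, gamma, beta)
```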
A nice side-effect of normalizing each channel with statistics computed from a random mini-batch (i.e. introducing a small amount of noise) is that it acts as a form of regularization, enhancing the generalization ability of the network.
However, since we compute $\mu_j$ and $\sigma_j$ across the mini-batch of inputs, these estimates depend upon the mini-batch itself. So at test time, if I have a mini-batch of, say, [cat, dog, frog] and another mini-batch of [cat, car, horse], and both contain the same cat image, the outputs for that cat would be different, because the model would compute different means and standard deviations for the two batches. The output our model produces for each element of a batch depends upon every other element in the batch.
Therefore, batch normalization behaves differently during training and testing. At test time, the batch norm layer does not compute the mean and standard deviation over the batch; instead, it uses running averages of the means and standard deviations collected during training. Batch norm thus becomes a linear operation at test time (the mean and standard deviation are constants) and can easily be fused with the preceding linear or convolution layer. The batch norm layer is therefore typically inserted after fully connected or convolutional layers, and before the nonlinearity.
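A rough sketch of this train/test split, assuming a hypothetical `BatchNorm1D` class and a momentum of 0.1 for the running averages (both my own choices), might look like this in NumPy:

```python
import numpy as np

class BatchNorm1D:
    """Minimal batch norm with separate train/test behavior (illustrative only)."""

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.gamma, self.beta = np.ones(dim), np.zeros(dim)
        self.running_mean = np.zeros(dim)
        self.running_var = np.ones(dim)
        self.eps, self.momentum = eps, momentum

    def forward(self, x, training=True):
        if training:
            mu, var = x.mean(axis=0), x.var(axis=0)
            # exponential moving averages collected during training
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            # test time: use the fixed running statistics, so the output for an
            # example no longer depends on the other elements of the batch
            mu, var = self.running_mean, self.running_var
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta
```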
To summarize, batch norm:
- Makes deep networks much easier to train!
- Allows higher learning rates, faster convergence
- Networks become more robust to initialization
- Acts as regularization during training
- Zero overhead at test-time: can be fused with conv (see the sketch after this list)
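To make the "fused with conv" point concrete, here is a minimal sketch of folding a test-time batch norm into a preceding fully connected layer (the fully connected case keeps the algebra simple; the helper name `fuse_linear_batchnorm` and the weight shape (D_out, D_in) are assumptions of this sketch, and the same algebra applies per output channel of a convolution):

```python
import numpy as np

def fuse_linear_batchnorm(W, b, gamma, beta, running_mean, running_var, eps=1e-5):
    """Fold a test-time batch norm into the preceding linear layer.

    y = gamma * ((W x + b) - mean) / sqrt(var + eps) + beta
      = W_fused x + b_fused
    """
    scale = gamma / np.sqrt(running_var + eps)   # shape (D_out,)
    W_fused = W * scale[:, None]                 # scale each output row
    b_fused = (b - running_mean) * scale + beta
    return W_fused, b_fused

# Sanity check: fused layer matches linear followed by test-time batch norm
x = np.random.randn(5, 4)
W, b = np.random.randn(3, 4), np.random.randn(3)
gamma, beta = np.random.randn(3), np.random.randn(3)
mean, var = np.random.randn(3), np.random.rand(3) + 0.5
Wf, bf = fuse_linear_batchnorm(W, b, gamma, beta, mean, var)
out_fused = x @ Wf.T + bf
out_ref = gamma * ((x @ W.T + b) - mean) / np.sqrt(var + 1e-5) + beta
assert np.allclose(out_fused, out_ref)
```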
One of the irritating features of batch norm is that it behaves differently during training and test time. One variant of batch norm, called Layer Normalization, has the same behavior during training and testing. The only difference is that rather than averaging over the batch dimension, we average over the feature dimension. The estimates therefore no longer depend on the other elements of the batch: each example gets its own mean and standard deviation. Layer normalization is commonly used in RNNs and Transformers.
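A minimal layer norm sketch in NumPy (assuming a 2D input of shape (N, D) and the hypothetical helper name `layernorm`):

```python
import numpy as np

def layernorm(x, gamma, beta, eps=1e-5):
    """Layer norm for an input of shape (N, D): statistics are computed per
    example over the feature dimension, so they do not depend on the batch."""
    mu = x.mean(axis=1, keepdims=True)     # per-example mean, shape (N, 1)
    var = x.var(axis=1, keepdims=True)     # per-example variance, shape (N, 1)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta            # gamma, beta have shape (D,)
```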
Another variant of this type of layer is called Instance Normalization, where we average over the spatial dimensions of the image. Each image (and each channel within it) gets its own mean and standard deviation, so this type of normalization also behaves the same during training and testing.
We also have one more variant, called Group Normalization, which, much like grouped convolutions, splits the channels into a number of groups and then normalizes over each subset of the channel dimension per example. This type of normalization tends to work quite well in some applications, such as object detection.
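For intuition, here is a rough NumPy sketch of instance norm and group norm on an (N, C, H, W) tensor; the learnable scale and shift are omitted for brevity, and the helper names are assumptions of this sketch:

```python
import numpy as np

def instance_norm(x, eps=1e-5):
    """Instance norm on (N, C, H, W): one mean/std per image and per channel,
    averaged over the spatial dimensions only."""
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def group_norm(x, num_groups, eps=1e-5):
    """Group norm on (N, C, H, W): channels are split into groups and each
    group is normalized per example over (channels in group, H, W)."""
    N, C, H, W = x.shape
    x = x.reshape(N, num_groups, C // num_groups, H, W)
    mu = x.mean(axis=(2, 3, 4), keepdims=True)
    var = x.var(axis=(2, 3, 4), keepdims=True)
    x = (x - mu) / np.sqrt(var + eps)
    return x.reshape(N, C, H, W)
```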
The figure below gives an intuition for these four types of normalization.