Batch Normalization
Normalizing activations in a neural network
During the rise of deep learning, one of the most important ideas has been an algorithm called batch normalization. It makes your hyperparameter search problem much easier and makes your neural network much more robust.
When training a model, normalizing the input features can speed up learning. It can turn the contours of the cost function from elongated to rounder, which is easier for an optimization algorithm such as gradient descent to handle.
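A minimal NumPy sketch of input normalization follows. It assumes the layout where each column of X is one example (shape (n_features, m)). The helper name normalize_inputs and the small eps guard are illustrative choices, not from these notes. The statistics come from the training set only and are reused for the test set so that both are transformed the same way.

```python
import numpy as np

def normalize_inputs(X_train, X_test, eps=1e-8):
    """Hypothetical helper: zero-mean, unit-variance scaling per feature.

    X_train, X_test: arrays of shape (n_features, m), one example per column.
    """
    mu = X_train.mean(axis=1, keepdims=True)     # per-feature mean over training examples
    sigma = X_train.std(axis=1, keepdims=True)   # per-feature standard deviation
    # Reuse the training statistics for the test set; eps guards against constant features
    return (X_train - mu) / (sigma + eps), (X_test - mu) / (sigma + eps)

rng = np.random.default_rng(0)
# Two features on very different scales, which gives elongated cost contours
X_train = np.vstack([rng.normal(0, 1000, 500), rng.normal(0, 1, 500)])
X_test = np.vstack([rng.normal(0, 1000, 100), rng.normal(0, 1, 100)])
Xn_train, Xn_test = normalize_inputs(X_train, X_test)
print(Xn_train.mean(axis=1).round(6), Xn_train.std(axis=1).round(6))
```

After scaling, both features have mean about 0 and standard deviation about 1 on the training set. This is what makes the cost contours rounder.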
What about a deeper model? Wouldn't it be nice if you could normalize the mean and variance of the values in a hidden layer, e.g. a[2], to make training w[3] and b[3] more efficient, since a[2] is the input to layer 3?
This is what batch normalization does, though instead of normalizing a[2], we normalize z[2], the value before the activation function.
Given some intermediate values in the NN for one layer, z(1), ..., z(m), over a mini-batch of m examples (each hidden unit is handled separately):
mu = 1/m * sum_i z(i)                               <- compute mean
sigma^2 = 1/m * sum_i (z(i) - mu)^2                 <- compute variance
z_norm(i) = (z(i) - mu) / sqrt(sigma^2 + epsilon)   <- normalize (epsilon avoids division by zero)
z_tilda(i) = gamma * z_norm(i) + beta               <- scale and shift
After that, we use z_tilda(i) instead of z(i) when computing the layer's activations, which then feed the next layer. gamma and beta are learnable parameters of the model. Moreover, if gamma = sqrt(sigma^2 + epsilon) and beta = mu, then z_tilda(i) = z(i), and the normalization is exactly undone.
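The per-unit steps above can be sketched in NumPy for a whole mini-batch at once. This sketch assumes Z holds one layer's pre-activations with shape (n_units, m), one column per example, and that gamma and beta have shape (n_units, 1). The name batchnorm_forward and eps=1e-5 are assumed for illustration. The usage lines check two things: with gamma = 1 and beta = 0 each unit ends up with mean 0 and variance about 1, and with gamma set to the square root of the variance plus epsilon and beta set to the mean, the output equals the input.

```python
import numpy as np

def batchnorm_forward(Z, gamma, beta, eps=1e-5):
    """Batch norm for one layer on one mini-batch (hypothetical helper).

    Z: pre-activations, shape (n_units, m); gamma, beta: shape (n_units, 1).
    Returns Z_tilde and Z_norm (Z_norm is what backprop needs for dgamma).
    """
    mu = Z.mean(axis=1, keepdims=True)                 # mean of each unit over the mini-batch
    var = ((Z - mu) ** 2).mean(axis=1, keepdims=True)  # variance of each unit over the mini-batch
    Z_norm = (Z - mu) / np.sqrt(var + eps)             # mean 0, variance ~1 per unit
    Z_tilde = gamma * Z_norm + beta                    # learnable scale and shift
    return Z_tilde, Z_norm

rng = np.random.default_rng(0)
Z = rng.normal(loc=3.0, scale=2.0, size=(4, 64))   # 4 hidden units, mini-batch of 64
eps = 1e-5

# gamma = 1, beta = 0: each unit ends up with mean 0 and variance ~1
Z_tilde, _ = batchnorm_forward(Z, np.ones((4, 1)), np.zeros((4, 1)), eps)
print(Z_tilde.mean(axis=1).round(6), Z_tilde.var(axis=1).round(3))

# gamma = sqrt(var + eps), beta = mean: the normalization is undone exactly
mu, var = Z.mean(axis=1, keepdims=True), Z.var(axis=1, keepdims=True)
Z_back, _ = batchnorm_forward(Z, np.sqrt(var + eps), mu, eps)
print(np.allclose(Z_back, Z))  # True
```

For any other gamma and beta, each unit's z_tilda has mean beta and variance close to gamma^2. Because of epsilon, the variance is slightly smaller than gamma^2.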
What it really does is normalize the mean and variance of these hidden unit values (the z(i)s) to some fixed mean and variance. That mean and variance could be 0 and 1, or some other values, and they are controlled by the parameters gamma and beta: z_tilda ends up with mean beta and variance of about gamma^2. This matters because we don't always want mean 0 and variance 1. For example, when we pass these z_tilda values to a sigmoid activation function, values clustered tightly around 0 would stay in the sigmoid's nearly linear region, while a larger variance or a different mean lets the network make use of the nonlinearity.
Fitting batch norm into a neural network
The gamma and beta here have no relation to the beta hyperparameters used in momentum, RMSProp, or Adam optimization. Moreover, we usually don't need to implement batch norm ourselves, since deep learning frameworks provide it.
During forward propagation the process is similar. However, instead of passing z[l] directly to the activation function to compute a[l], we first compute z_tilda[l] from z[l] and pass z_tilda[l] to the activation function to compute a[l].
During back propagation, we compute dgamma[l] and dbeta[l] for each hidden layer and then update gamma[l] and beta[l] with gradient descent, RMSProp, or Adam, the same way we update w[l].
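Since z_tilda = gamma * z_norm + beta, the gradients for gamma and beta follow directly from the chain rule. This minimal sketch assumes that dZ_tilde is the gradient of the cost with respect to Z_tilde for the whole mini-batch, with shape (n_units, m) and any 1/m averaging already included. It also assumes that Z_norm was cached during the forward pass. The helper names and the plain gradient-descent update are illustrative; RMSProp or Adam would replace only the update step. The full gradient with respect to z[l], which is needed to keep backpropagating, is more involved and is not shown.

```python
import numpy as np

def batchnorm_param_grads(dZ_tilde, Z_norm):
    """Gradients of the cost w.r.t. gamma and beta (hypothetical helper).

    dZ_tilde, Z_norm: shape (n_units, m). Returns dgamma, dbeta of shape (n_units, 1).
    """
    # Each example contributes z_norm * upstream gradient to gamma, and the upstream gradient to beta
    dgamma = np.sum(dZ_tilde * Z_norm, axis=1, keepdims=True)
    dbeta = np.sum(dZ_tilde, axis=1, keepdims=True)
    return dgamma, dbeta

def gradient_descent_step(gamma, beta, dgamma, dbeta, learning_rate=0.01):
    """Plain gradient descent update; RMSProp or Adam would change only this step."""
    return gamma - learning_rate * dgamma, beta - learning_rate * dbeta

rng = np.random.default_rng(1)
Z_norm = rng.normal(size=(3, 8))       # cached from the forward pass (3 units, 8 examples)
dZ_tilde = rng.normal(size=(3, 8))     # upstream gradient from the activation function
gamma, beta = np.ones((3, 1)), np.zeros((3, 1))
dgamma, dbeta = batchnorm_param_grads(dZ_tilde, Z_norm)
gamma, beta = gradient_descent_step(gamma, beta, dgamma, dbeta)
```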
Initially z[l] = w[l] * a[l-1] + b[l]. However, batch norm subtracts the mini-batch mean of z[l], so any constant b[l] added to every example is cancelled out, and beta[l] takes over its role as the shift. Thus, we can omit b[l] and the computation of db[l], and the formula becomes z[l] = w[l] * a[l-1].
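Here is a small check of the bias argument, written as a sketch. It computes one layer's forward pass with batch norm twice, once without b[l] and once with a bias added. Both runs produce the same activations because the mini-batch mean absorbs the constant. The function name, the ReLU activation, and the shapes (W of shape (n_l, n_{l-1}), A_prev of shape (n_{l-1}, m)) are assumptions for illustration.

```python
import numpy as np

def layer_forward_bn(W, A_prev, gamma, beta, b=None, eps=1e-5):
    """Hypothetical layer: z = W a_prev (+ b), then batch norm, then ReLU."""
    Z = W @ A_prev
    if b is not None:
        Z = Z + b                                   # optional bias, broadcast over the mini-batch
    mu = Z.mean(axis=1, keepdims=True)              # b shifts this mean by exactly b ...
    var = Z.var(axis=1, keepdims=True)              # ... and leaves the variance unchanged
    Z_tilde = gamma * (Z - mu) / np.sqrt(var + eps) + beta
    return np.maximum(0.0, Z_tilde)                 # ReLU activation (assumed)

rng = np.random.default_rng(2)
W = rng.normal(size=(5, 3))
A_prev = rng.normal(size=(3, 32))
gamma, beta = np.ones((5, 1)), np.zeros((5, 1))
b = rng.normal(size=(5, 1))

A_no_bias = layer_forward_bn(W, A_prev, gamma, beta)
A_with_bias = layer_forward_bn(W, A_prev, gamma, beta, b=b)
print(np.allclose(A_no_bias, A_with_bias))  # True
```

Because the bias would have no effect on the output, keeping it would only add a useless parameter. The learnable beta supplies the shift instead.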
Why does Batch Norm work?
Normalizing input features can speed up learning; batch norm does a similar thing for the values in the hidden layers.
It also makes the weights in deeper layers more robust to changes in the weights of earlier layers of the network.
Suppose we train a network to recognize cats using only black-and-white cat pictures, and then test it on colored cat pictures. It may not work well. The same function (cat or not) maps images to labels in both cases, but we can't expect the learning algorithm to discover that function just by looking at a subset of the data whose distribution is different. This is called covariate shift: if the distribution of the inputs X changes, we may need to retrain even if the ground-truth function from X to y stays the same.
Why does this affect a neural network?
From the perspective of the 3rd hidden layer, it gets the values a1[2], a2[2], ..., a4[2] and has to find a way to map them to y_hat. The same holds for the 4th and 5th hidden layers: they need to learn their parameters to do a good job.
Meanwhile, the network is also adapting the parameters w[1], b[1] and w[2], b[2]. As these parameters change, the values in a[2] also change. So from the perspective of the 3rd hidden layer, its input values are changing all the time, and it suffers from the covariate shift problem.
What batch norm does is reduce how much the distribution of these hidden unit values shifts around. The values z[2] can still change, but no matter how they change, their mean and variance stay the same (as set by beta[2] and gamma[2]). So batch norm limits how much the input values of the later layers change. These values become more stable, and the later layers of the neural network have firmer ground to stand on.
Thus it allows each layer to learn a bit more independently of the other layers, which speeds up learning of the whole network.
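To illustrate the stabilizing effect, the sketch below passes two very different distributions of z[2] values through the same batch norm step, as if the earlier layers' parameters had changed in between. With fixed gamma and beta, both come out with the same per-unit mean (0.2) and standard deviation (about 1.5). The helper name and the numbers are made up for illustration. Only the mean and variance are pinned down; the full shape of the distribution can still change.

```python
import numpy as np

def batchnorm(Z, gamma, beta, eps=1e-5):
    """Hypothetical helper: per-unit batch norm over the mini-batch axis (axis 1)."""
    mu = Z.mean(axis=1, keepdims=True)
    var = Z.var(axis=1, keepdims=True)
    return gamma * (Z - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(3)
gamma, beta = np.full((2, 1), 1.5), np.full((2, 1), 0.2)

Z_before = rng.normal(loc=0.0, scale=1.0, size=(2, 256))  # z[2] early in training
Z_after = rng.normal(loc=7.0, scale=4.0, size=(2, 256))   # z[2] after w[1], w[2] have changed

for Z in (Z_before, Z_after):
    Z_tilde = batchnorm(Z, gamma, beta)
    # Per-unit mean and standard deviation of what the next layer actually sees
    print(Z_tilde.mean(axis=1).round(3), Z_tilde.std(axis=1).round(3))
```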
Batch norm also has a slight regularization effect, because it is used with mini-batches.
Each mini-batch is scaled by the mean and variance computed on just that mini-batch. This adds some noise to the values z[l] within that mini-batch. Thus, similar to dropout, it adds some noise to each hidden layer's activations.
Generally, as the mini-batch size increases, this noise decreases, and so does the regularization effect. However, we shouldn't use batch norm as a regularizer, since that is not its purpose.
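The noise comes from the mini-batch statistics themselves. The rough NumPy illustration below uses synthetic values for a single hidden unit, with arbitrary data and batch sizes. The mean computed on each mini-batch fluctuates around the full-data mean, and the fluctuation shrinks as the mini-batch grows. This is why larger mini-batches give a weaker regularization effect.

```python
import numpy as np

rng = np.random.default_rng(4)
z = rng.normal(loc=2.0, scale=1.0, size=8192)   # one hidden unit's z values over the training set

def spread_of_batch_means(z, batch_size):
    """Standard deviation of the per-mini-batch means (hypothetical helper)."""
    n_batches = len(z) // batch_size
    batch_means = z[: n_batches * batch_size].reshape(n_batches, batch_size).mean(axis=1)
    return batch_means.std()

for batch_size in (16, 64, 256, 1024):
    # Expected to be roughly 1 / sqrt(batch_size) here, since the unit's values have
    # standard deviation 1 (noisier for the largest size, which has only 8 batches)
    print(batch_size, round(spread_of_batch_means(z, batch_size), 4))
```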
Batch norm at test time
Batch norm processes your data one mini-batch at a time, but at test time you might need to process examples one at a time. Let's see how to adapt the network to do that. Computing the mean and variance of a single example isn't meaningful. Instead, during training we keep an exponentially weighted (running) average of the mean and variance computed on each mini-batch, separately for each layer. At test time we normalize with these estimates.
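The following minimal sketch shows how the running estimates might be kept. It assumes one BatchNorm object per layer, pre-activations of shape (n_units, m), and a momentum of 0.9 for the exponentially weighted average. The class name, the initial values, and the momentum are illustrative choices, not any specific framework's API. During training, each mini-batch is normalized with its own statistics, and those statistics are folded into the running averages. At test time the running averages are used instead, so even a single example (m = 1) can be normalized.

```python
import numpy as np

class BatchNorm:
    """Hypothetical batch norm layer that tracks running statistics for test time."""

    def __init__(self, n_units, momentum=0.9, eps=1e-5):
        self.gamma = np.ones((n_units, 1))
        self.beta = np.zeros((n_units, 1))
        self.running_mu = np.zeros((n_units, 1))   # exponentially weighted average of means
        self.running_var = np.ones((n_units, 1))   # exponentially weighted average of variances
        self.momentum = momentum
        self.eps = eps

    def forward(self, Z, training):
        if training:
            # Normalize with this mini-batch's statistics ...
            mu = Z.mean(axis=1, keepdims=True)
            var = Z.var(axis=1, keepdims=True)
            # ... and fold them into the running averages
            self.running_mu = self.momentum * self.running_mu + (1 - self.momentum) * mu
            self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        else:
            # Test time: use the running estimates instead of the batch's own statistics
            mu, var = self.running_mu, self.running_var
        Z_norm = (Z - mu) / np.sqrt(var + self.eps)
        return self.gamma * Z_norm + self.beta

rng = np.random.default_rng(5)
bn = BatchNorm(n_units=3)
for _ in range(200):                                    # 200 training mini-batches of 64 examples
    bn.forward(rng.normal(loc=5.0, scale=2.0, size=(3, 64)), training=True)

x = rng.normal(loc=5.0, scale=2.0, size=(3, 1))        # a single test example
print(bn.running_mu.ravel().round(2), bn.running_var.ravel().round(2))
print(bn.forward(x, training=False).ravel())
```

After training, the running mean and variance should be close to the data's true values (5 and 4 here). Calling forward with training=False leaves them unchanged.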
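Either this running average or recomputing mu and sigma^2 over the whole training set with the final network works well in practice, so we shouldn't worry much about it. When we use a framework, it will have a default way to estimate the mean and variance.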