Batch Normalization论文总结
作者:互联网
Batch Normalization要解决的问题
训练深度神经网络是复杂的,因为在训练过程中,每一层参数的更新变化,都会影响到下一层输入的分布,而且随着网络深度的增加,这种影响会不断放大。每一层输入分布的变化就迫使每一层要不断适应新分布,所以受到网络内部分布变化的影响,
1.训练网络的学习率不能太大,这就减慢了网络的训练速度;
2.需要谨慎初始化模型参数;
3.容易使非线性函数(sigmoid函数)达到饱和区域。sigmoid函数g(x)=1+exp(−x)1,函数如下图所示。
由于x受到w,b以及之前所有层的参数的影响,在训练过程中这些参数的变化可能会使x的许多维进入函数的饱和区域,使得这些维上的梯度为0(梯度消失),减缓收敛速度。
文章中将内部分布变化这一现象称为内部协变量变换(internal covariate shift),而解决这一问题的办法就是标准化(normalize)每一层的输入,让标准化作为模型的一部分,使得整个网络流过的数据都是同分布的,并且标准化是在每一个mini-batch上进行的,这也是Batch Normalization名字的由来。(mini-batch的优势:首先,loss在mini-batch上的梯度是对loss在整个训练集上的梯度的估计,batch越大,估计越准确,效果越好;第二,由于并行计算,mini-batch的效率高。)
Batch Normalization算法
对于一个d维输入x=(x(1)...x(d)),BN的操作是对其每一维进行标准化
x^(k)=Var[x(k)]x(k)−E[x(k)]
但是如此简单暴力的将输入的每一维限制在均值为0方差为1的同分布内,会破坏每一层的表达能力。例如BN层会将输入特征限制在非线性函数(如sigmoid)的线性部分,多个线性层叠加和单个线性层是一样的,显然会使网络的表达能力下降。所以文章添加了两个参数γ(k),β(k),x在标准化后,再用这两个参数进行平移缩放(对方差进行缩放scale,对均值进行平移shift),如下所示
y(k)=γ(k)x^(k)+β(k)
γ(k),β(k)是两个可学习的参数,用来恢复每层的表达能力,不再是单一的迫使每层同分布。BN算法如下图所示,图中的x是指每一维的特征x(k)。
这里也可以看出,每一层都会有一对参数γ,β;每一层也会计算出相应mini-batch的μB,σB2。这四个量都是多维向量,每一维对应输入向量的一维。训练过程中,每层γ,β都要更新,而且每层对应mini-batch的μB,σB2也在变动(随着参数更新,每层输入会变)。
反向传播更新参数的过程如下所示。
训练与预测时的Batch Normalization
训练和预测的算法如下所示。在训练的时候,文章中并没有将输入的每一维属性进行BN操作,而是选定了一个K大小的属性子集,然后在每个mini-batch上对这一子集进行BN操作,即Algorithm1(见上图)。在预测的时候,没有必要也不希望进行和训练时一样的操作,因为可能预测的时候只有一个样本,它的均值和方差是没有意义的。所以在预测时,对于每一层,我们将训练时的所有mini-batch对应的均值和方差取均值作为这一层的均值和方差,即
E(x)←EB[μB]
Var(x)←m−1mEB[σB2]
(这里是对σB2的无偏估计,即m−1mEB[σB2]=EB[m−1mσB2]=EB[m−1mm1i=1∑m(xi−μB)2]=EB[m−11i=1∑m(xi−μB)2])
最终y=Var[x]+ϵγ⋅x+(β−Var[x]+ϵγE[x])
与上述的神经网络对输入向量的每一维进行BN操作不同,对于卷积神经网络,BN操作的对象是以每一个特征图为单位的,即对输入特征图的每一个通道的特征图在mini-batch上求均值和方差,并对应一对参数γ,β。
Batch Normalization的优势
1.可以使用更大的学习率以及对参数初始化不用太小心谨慎
BN层会将每层输入强行拉回均值为0方差为1,或附近(γ,β平移缩放),对于某些激活函数(如sigmoid)这样避免了达到其饱和区域,从而不会由于梯度太小引起梯度消失或陷入局部最优解。
学习率过大,会使参数增长或下降过快,参数过大或者过小会使反向传播中梯度变大,从而导致梯度爆炸。而使用BN层会使得反向传播中的梯度对参数的大小不敏感。例如让参数扩大a倍,有
BN((aW)u)=γa2σ2+ϵaWu−aμ+β=BN(Wu)=γσ2+ϵWu−μ+β(由于ϵ很小可忽略),那么
∂u∂BN((aW)u)=∂u∂BN(Wu)
∂(aW)∂BN((aW)u)=a1∂W∂BN(Wu)
可以看到参数越大并不会使梯度增大,反而使梯度更小。
2.有正则化的效果
BN层每次以一个mini-batch为单位进行操作,减少了单个数据尤其是异常数据对网络的影响,从而减少过拟合的风险。
论文地址:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
标签:space,dfrac,BN,论文,Batch,Wu,B2,Var,Normalization 来源: https://blog.csdn.net/ysl_ysl123/article/details/94194969