python神经网络Batch Normalization底层原理详解

作者:Bubbliiiing 时间:2021-01-28 12:50:08 

什么是Batch Normalization

Batch Normalization是神经网络中常用的层,解决了很多深度学习中遇到的问题,我们一起来学习一哈。

Batch Normalization是由google提出的一种训练优化方法。参考论文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift。

Batch Normalization的名称为批标准化,它的功能是使得输入的X数据符合同一分布,从而使得训练更加简单、快速。

一般来讲,Batch Normalization会放在卷积层后面,即卷积 + 标准化 + 激活函数。

其计算过程可以简单归纳为以下3点:

1、求数据均值。

2、求数据方差。

3、数据进行标准化。

Batch Normalization的计算公式

Batch Normalization的计算公式主要看如下这幅图:

python神经网络Batch Normalization底层原理详解

这个公式一定要静下心来看,整个公式可以分为四行:

1、对输入进来的数据X进行均值求取。

2、利用输入进来的数据X减去第一步得到的均值,然后求平方和,获得输入X的方差。

3、利用输入X、第一步获得的均值和第二步获得的方差对数据进行归一化,即利用X减去均值,然后除上方差开根号。方差开根号前需要添加上一个极小值。

4、引入γ和β变量,对输入进来的数据进行缩放和平移。利用γ和β两个参数,让我们的网络可以学习恢复出原始网络所要学习的特征分布。

前三步是标准化工序,最后一步是反标准化工序。

Bn层的好处

1、加速网络的收敛速度。在神经网络中,存在内部协变量偏移的现象,如果每层的数据分布不同的话,会导致非常难收敛,如果把每层的数据都在转换在均值为零,方差为1的状态下,这样每层数据的分布都是一样的,训练会比较容易收敛。

2、防止梯度 * 和梯度消失。对于梯度消失而言,以Sigmoid函数为例,它会使得输出在[0,1]之间,实际上当x到了一定的大小,sigmoid激活函数的梯度值就变得非常小,不易训练。归一化数据的话,就能让梯度维持在比较大的值和变化率;

对于梯度 * 而言,在方向传播的过程中,每一层的梯度都是由上一层的梯度乘以本层的数据得到。如果归一化的话,数据均值都在0附近,很显然,每一层的梯度不会产生 * 的情况。

3、防止过拟合。在网络的训练中,Bn使得一个minibatch中所有样本都被关联在了一起,因此网络不会从某一个训练样本中生成确定的结果,这样就会使得整个网络不会朝这一个方向使劲学习。一定程度上避免了过拟合。

为什么要引入γ和β变量

Bn层在进行前三步后,会引入γ和β变量,对输入进来的数据进行缩放和平移。

γ和β变量是网络参数,是可学习的。

引入γ和β变量进行缩放平移可以使得神经网络有自适应的能力,在标准化效果好时,尽量不抵消标准化的作用,而在标准化效果不好时,尽量去抵消一部分标准化的效果,相当于让神经网络学会要不要标准化,如何折中选择。

Bn层的代码实现

Pytorch代码看起来比较简单,而且和上面的公式非常符合,可以学习一下,参考自

https://www.jb51.net/article/247197.htm

def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9):
   if not is_training:
       x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
   else:
       mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
       var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
       x_hat = (x - mean) / torch.sqrt(var + eps)
       moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
       moving_var = momentum * moving_var + (1.0 - momentum) * var
   Y = gamma * x_hat + beta
   return Y, moving_mean, moving_var
class BatchNorm2d(nn.Module):
   def __init__(self, num_features):
       super(BatchNorm2d, self).__init__()
       shape = (1, num_features, 1, 1)
       self.gamma = nn.Parameter(torch.ones(shape))
       self.beta = nn.Parameter(torch.zeros(shape))
       self.register_buffer('moving_mean', torch.zeros(shape))
       self.register_buffer('moving_var', torch.ones(shape))
   def forward(self, x):
       if self.moving_mean.device != x.device:
           self.moving_mean = self.moving_mean.to(x.device)
           self.moving_var = self.moving_var.to(x.device)
       y, self.moving_mean, self.moving_var = batch_norm(self.training,
           x, self.gamma, self.beta, self.moving_mean,
           self.moving_var, eps=1e-5, momentum=0.9)
       return y

来源:https://blog.csdn.net/weixin_44791964/article/details/114998793

标签:python,神经网络,Batch,Normalization,BatchNor
0
投稿

猜你喜欢

  • 利用Python制作一个愚人节整蛊消息框

    2022-08-07 22:06:53
  • 有序列表 li ol

    2008-07-30 12:31:00
  • Python元组定义及集合的使用

    2023-11-22 12:32:03
  • MySQL数据库安全解决方案

    2009-10-17 21:36:00
  • Python OpenCV去除字母后面的杂线操作

    2023-08-02 15:18:47
  • python中opencv图像叠加、图像融合、按位操作的具体实现

    2023-11-11 21:39:21
  • 使用Python+wxpy 找出微信里把你删除的好友实例

    2023-05-09 05:12:28
  • Python协程的2种实现方式分享

    2022-12-21 12:42:56
  • 用图片做网站输入验证的构想

    2009-02-02 10:18:00
  • Python+requests+unittest执行接口自动化测试详情

    2023-07-30 15:08:37
  • Python字符串hashlib加密模块使用案例

    2023-08-02 12:06:24
  • go语言csrf库使用实现原理示例解析

    2023-08-07 03:34:38
  • 谈谈Javascript中的++和–操作符

    2009-05-08 11:43:00
  • python爬取分析超级大乐透历史开奖数据第1/2页

    2021-03-15 17:02:59
  • ASP程序与SQL存储过程结合使用详解

    2011-03-25 10:50:00
  • ajax标签导航实例详解教程

    2008-02-01 10:54:00
  • Javascript 小游戏,“是男人坚持 100 次”

    2009-01-22 14:25:00
  • Python超简单容易上手的画图工具库(适合新手)

    2021-12-06 04:05:23
  • 参数传递解决window.open的session变量丢失

    2007-10-22 17:40:00
  • 网页切片算法的若干问题

    2008-04-17 13:10:00
  • asp之家 网络编程 m.aspxhome.com