2025年pytorch BatchNorm参数详解,计算过程

pytorch BatchNorm参数详解,计算过程目录 说明 BatchNorm1d 参数 num features eps momentum affine track running stats BatchNorm1d 训练时前向传播 BatchNorm1d 评估时前向传播 总结 说明 网络训练时和网络评估时 BatchNorm 模块的计算方式不同 如果一个网络里包含了 BatchNorm 则在训练时需要先调用 train

目录

说明

BatchNorm1d参数

num_features

eps

momentum

affine

track_running_stats

BatchNorm1d训练时前向传播

BatchNorm1d评估时前向传播

总结

----

说明

网络训练时和网络评估时,BatchNorm模块的计算方式不同。如果一个网络里包含了BatchNorm,则在训练时需要先调用train(),使网络里的BatchNorm模块的training=True(默认是True),在网络评估时,需要先调用eval(),使网络里的BatchNorm模块的training=False。

BatchNorm1d参数

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

num_features

输入维度是(N, C, L)时,num_features应该取C;这里N是batch size,C是数据的channel,L是数据长度。

输入维度是(N, L)时,num_features应该取L;这里N是batch size,L是数据长度,这时可以认为每条数据只有一个channel,省略了C

eps

对输入数据进行归一化时加在分母上,防止除零,详情见下文。

momentum

更新全局均值running_mean和方差running_var时使用该值进行平滑,详情见下文。

affine

设为True时,BatchNorm层才会学习参数

,否则不包含这两个变量,变量名是weight和bias,详情见下文。

track_running_stats

设为True时,BatchNorm层会统计全局均值running_mean和方差running_var,详情见下文。

BatchNorm1d训练时前向传播

首先对输入batch求和,并用这两个结果把batch归一化,使其均值为0,方差为1。归一化公式用到了eps(),即。如下输入内容,shape是(3, 4),即batch_size=3,此时num_features需要传入4。

如果==True,则使用momentum更新模块内部的(初值是[0., 0., 0., 0.])和(初值是[1., 1., 1., 1.]),更新公式是,其中代表更新后的和,表示更新前的和,表示当前batch的均值和无偏样本方差。

如果==False,则BatchNorm中不含有和两个变量。

如果==True,则对归一化后的batch进行仿射变换,即乘以模块内部的(初值是[1., 1., 1., 1.])然后加上模块内部的(初值是[0., 0., 0., 0.]),这两个变量会在反向传播时得到更新。

如果==False,则BatchNorm中不含有和两个变量,什么都都不做。

BatchNorm1d评估时前向传播

如果track_running_stats==True,则对batch进行归一化,公式为

,注意这里的均值和方差是running_mean和running_var,在网络训练时统计出来的全局均值和无偏样本方差。

如果track_running_stats==False,则对batch进行归一化,公式为

,注意这里的均值和方差是batch自己的mean和var,此时BatchNorm里不含有running_mean和running_var。注意此时使用的是无偏样本方差(和训练时不同),因此如果batch_size=1,会使分母为0,就报错了。

如果affine==True,则对归一化后的batch进行放射变换,即乘以模块内部的weight然后加上模块内部的bias,这两个变量都是网络训练时学习到的。

如果affine==False,则BatchNorm中不含有weight和bias两个变量,什么都不做。

总结

在使用batchNorm时,通常只需要指定num_features就可以了。网络训练前调用train(),训练时BatchNorm模块会统计全局running_mean和running_var,学习weight和bias,即文献中的

。网络评估前调用eval(),评估时,对传入的batch,使用统计的全局running_mean和running_var对batch进行归一化,然后使用学习到的weight和bias进行仿射变换。

编程小号
上一篇 2025-03-15 20:11
下一篇 2025-02-06 08:46

相关推荐

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/hz/119176.html