踩坑:pytorch中eval模式下结果远差于train模式介绍

作者:yucong96 时间:2021-10-06 22:27:49 

首先,eval模式和train模式得到不同的结果是正常的。我的模型中,eval模式和train模式不同之处在于Batch Normalization和Dropout。Dropout比较简单,在train时会丢弃一部分连接,在eval时则不会。Batch Normalization,在train时不仅使用了当前batch的均值和方差,也使用了历史batch统计上的均值和方差,并做一个加权平均(momentum参数)。在test时,由于此时batchsize不一定一致,因此不再使用当前batch的均值和方差,仅使用历史训练时的统计值。

我出bug的现象是,train模式下可以收敛,但一旦在测试中切换到了eval模式,结果就很差。如果在测试中仍沿用train模式,反而可以得到不错的结果。为了确保是程序bug而不是算法本身就不适合于预测,我在测试时再次使用了训练集,正常情况下此时应发生过拟合,正确率一定会很高,然而eval模式下正确率仍然很低。参照网上的一些说法(Performance highly degraded when eval() is activated in the test phase
),我调大了batchsize,降低了BN层的momentum,检查了是否存在不同层使用相同BN层的bug,均不见效。有一种方法说应在BN层设置track_running_stats为False,它虽然带来了好的效果,但实际上它只不过是不用eval模式,切回train模式罢了,所以也不对。

学习了在训练过程中,如何将BN层中统计的均值和方差输出。即在forward()中,


# bn是一个BN层,torch.nn.batch_normalization(...)
print(bn.running_mean)
print(bn.running_var)

同时学习了如何输出一个Tensor自身的均值和方差,即


# x是一个Tensor,dims是需要计算的维度
print(x.cpu().detach().numpy().mean(dims)
print(x.cpu().detach().numpy().var(dims)

观察每一层的输出结果,发现出现了很大的方差,才猛然意识到自己的输入数据没有做归一化(事后想想也确实如此,毕竟模型和训练方法都是github上参考别人的,出错概率很小;反而是自己写的DataSet部分,其实是最容易出错的)。给模型加上归一化后,eval和train的结果就没有问题了。

再次验证了我的观点:越是玄学的问题,越是 * 的bug。

补充知识:Pytorch中的train和eval用法注意点

1.介绍

一般情况,model.train()是在训练的时候用到,model.eval()是在测试的时候用到

2.用法

如果模型中没有类似于BN这样的归一化或者Dropout,model.train()和model.eval()可以不要(建议写一下,比较安全),并且model.train()和model.eval()得到的效果是一样

如果模型中有类似于BN这样的归一化或者Dropout,并且程序需要边训练和边测试,最好就是用model.eval()测试完之后,后面补一个model.train()。

其中model.train()是保证BN用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接(结果是取了平均)

来源:https://blog.csdn.net/yucong96/article/details/88652964

标签:pytorch,eval,train
0
投稿

猜你喜欢

  • 详解python 拆包可迭代数据如tuple, list

    2022-01-08 19:28:43
  • 详解Python中break语句的用法

    2021-12-21 22:18:17
  • asp日期转换成汉字格式程序

    2008-07-08 18:19:00
  • 教大家使用Python SqlAlchemy

    2022-12-02 01:40:17
  • HTML的基本元素

    2010-03-16 12:39:00
  • shtml网页SSI使用详解

    2008-02-20 19:13:00
  • Python中更优雅的日志记录方案详解

    2023-09-02 13:43:03
  • 动态SQL中返回数值的实现代码

    2012-01-05 18:53:54
  • Python中的Numpy矩阵操作

    2021-10-06 07:19:23
  • pandas 把数据写入txt文件每行固定写入一定数量的值方法

    2021-06-13 20:08:14
  • Python3.10耙梳加密算法Encryption种类及开发场景

    2021-07-19 00:46:55
  • Python全栈之列表数据类型详解

    2023-05-05 15:27:10
  • PHP实现克鲁斯卡尔算法实例解析

    2023-09-08 19:35:57
  • python scrapy框架中Request对象和Response对象的介绍

    2021-04-02 07:29:59
  • ASP中遍历和操作Application对象的集合

    2007-09-13 12:45:00
  • rs.open sql,conn,1,1与rs.open sql,conn,1.3还有rs.open sql,conn,3,2区别

    2011-02-24 10:49:00
  • 利用ASP从远程服务器上接收XML数据

    2007-08-23 12:49:00
  • 友情连接地址代码-线线表格

    2010-07-01 16:26:00
  • Numpy对于NaN值的判断方法

    2022-12-15 15:08:21
  • asp IsValidEmail 验证邮箱地址函数(email)

    2011-03-03 10:42:00
  • asp之家 网络编程 m.aspxhome.com