踩坑:pytorch中eval模式下结果远差于train模式介绍

作者:yucong96 时间:2021-10-06 22:27:49 

首先,eval模式和train模式得到不同的结果是正常的。我的模型中,eval模式和train模式不同之处在于Batch Normalization和Dropout。Dropout比较简单,在train时会丢弃一部分连接,在eval时则不会。Batch Normalization,在train时不仅使用了当前batch的均值和方差,也使用了历史batch统计上的均值和方差,并做一个加权平均(momentum参数)。在test时,由于此时batchsize不一定一致,因此不再使用当前batch的均值和方差,仅使用历史训练时的统计值。

我出bug的现象是,train模式下可以收敛,但一旦在测试中切换到了eval模式,结果就很差。如果在测试中仍沿用train模式,反而可以得到不错的结果。为了确保是程序bug而不是算法本身就不适合于预测,我在测试时再次使用了训练集,正常情况下此时应发生过拟合,正确率一定会很高,然而eval模式下正确率仍然很低。参照网上的一些说法(Performance highly degraded when eval() is activated in the test phase
),我调大了batchsize,降低了BN层的momentum,检查了是否存在不同层使用相同BN层的bug,均不见效。有一种方法说应在BN层设置track_running_stats为False,它虽然带来了好的效果,但实际上它只不过是不用eval模式,切回train模式罢了,所以也不对。

学习了在训练过程中,如何将BN层中统计的均值和方差输出。即在forward()中,


# bn是一个BN层,torch.nn.batch_normalization(...)
print(bn.running_mean)
print(bn.running_var)

同时学习了如何输出一个Tensor自身的均值和方差,即


# x是一个Tensor,dims是需要计算的维度
print(x.cpu().detach().numpy().mean(dims)
print(x.cpu().detach().numpy().var(dims)

观察每一层的输出结果,发现出现了很大的方差,才猛然意识到自己的输入数据没有做归一化(事后想想也确实如此,毕竟模型和训练方法都是github上参考别人的,出错概率很小;反而是自己写的DataSet部分,其实是最容易出错的)。给模型加上归一化后,eval和train的结果就没有问题了。

再次验证了我的观点:越是玄学的问题,越是 * 的bug。

补充知识:Pytorch中的train和eval用法注意点

1.介绍

一般情况,model.train()是在训练的时候用到,model.eval()是在测试的时候用到

2.用法

如果模型中没有类似于BN这样的归一化或者Dropout,model.train()和model.eval()可以不要(建议写一下,比较安全),并且model.train()和model.eval()得到的效果是一样

如果模型中有类似于BN这样的归一化或者Dropout,并且程序需要边训练和边测试,最好就是用model.eval()测试完之后,后面补一个model.train()。

其中model.train()是保证BN用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接(结果是取了平均)

来源:https://blog.csdn.net/yucong96/article/details/88652964

标签:pytorch,eval,train
0
投稿

猜你喜欢

  • 使用MySQL数据库的23个注意事项

    2010-03-18 15:46:00
  • numpy数组坐标轴问题解决

    2022-10-23 02:48:12
  • django实现日志按日期分割

    2023-07-20 04:25:21
  • Mysql中Join的使用实例详解

    2024-01-26 05:04:36
  • Bootstrapvalidator校验、校验清除重置的实现代码(推荐)

    2024-04-10 13:52:57
  • pyMySQL SQL语句传参问题,单个参数或多个参数说明

    2024-01-18 13:21:33
  • Python Multiprocessing多进程 使用tqdm显示进度条的实现

    2021-04-03 19:15:08
  • 巧妙的Sql函数日期处理方法

    2009-05-25 17:59:00
  • 用 Python 元类的特性实现 ORM 框架

    2022-02-12 12:45:24
  • Python中celery的使用

    2022-10-22 14:03:17
  • django定期执行任务(实例讲解)

    2022-12-13 20:43:35
  • javascript框架设计之框架分类及主要功能

    2024-04-18 09:33:40
  • Python实现FTP弱口令扫描器的方法示例

    2023-12-16 04:46:45
  • HTML+JS实现猜拳游戏的示例代码

    2024-04-16 09:31:25
  • 深入MYSQL字符数字转换的详解

    2024-01-18 04:20:11
  • 让字体美起来

    2011-06-14 09:50:21
  • caffe的python接口生成配置文件学习

    2023-07-09 04:46:41
  • 在Python中使用NLTK库实现对词干的提取的教程

    2022-11-04 15:13:53
  • Vscode常用快捷键列表、插件安装、console.log详解

    2023-02-11 01:29:04
  • Go+Kafka实现延迟消息的实现示例

    2024-05-22 10:14:29
  • asp之家 网络编程 m.aspxhome.com