Pytorch 中net.train 和 net.eval的使用说明

作者:Never-Giveup 时间:2021-11-15 11:40:37 

在训练模型时会在前面加上:


model.train()

在测试模型时在前面使用:


model.eval()

同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。

训练时是正对每个min-batch的,但是在测试中往往是针对单张图片,即不存在min-batch的概念。

由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有batch的均值和方差。

所有Batch Normalization的训练和测试时的操作不同

在训练中,每个隐层的神经元先乘概率P,然后在进行激活,在测试中,所有的神经元先进行激活,然后每个隐层神经元的输出乘P。

补充:Pytorch踩坑记录——model.eval()

最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行inference准确率直接掉了5个点,尼玛,这简直不能忍啊~本菜鸡下意识地感知到我肯定又在哪里写了bug了~~~于是开始到处排查,从model load到data load,最终在一个被我封装好的module的犄角旮旯里找到了问题,于是顺便就在这里总结一下,避免以后再犯。

对于训练好的模型加载进来准确率和原先的不符,比较常见的有两方面的原因:

1)data

2)model.state_dict()

1) data

数据方面,检查前后两次加载的data有没有发生变化。首先检查 transforms.Normalize 使用的均值和方差是否和训练时相同;另外检查在这个过程中数据是否经过了存储形式的改变,这有可能会带来数据精度的变化导致一定的信息丢失。

比如我过用的其中一个数据集,原先将图片存储成向量形式,但其对应的是“png”格式的数据(后来在原始文件中发现了相应的描述。),而我进行了一次data-to-img操作,将向量转换成了“jpg”形式,这时加载进来便造成了掉点。

2)model.state_dict()

第一方面造成的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那就误差大了。

如果是参数没有正确加载进来则比较容易发现,这时准确率非常低,几乎等于瞎猜。

而我这次遇到的情况是,准确率并不是特别低,只掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行inference时,忘了写 ‘model.eval()',造成了模型的参数发生变化,再次调用则出现了掉点。

于是又回顾了一下model.eval()和model.train()的具体作用。如下:

model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch

Normalization 和 Dropout 方法模式不同:

a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;

b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练q起到防止网络过拟合的问题。

因此,在使用PyTorch进行训练和测试时一定要记得把实例化的model指定train/eval。


model.eval()   vs   torch.no_grad()

虽然二者都是eval的时候使用,但其作用并不相同:

model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。 见下方代码:


 import torch
 import torch.nn as nn

drop = nn.Dropout()
 x = torch.ones(10)

# Train mode  
 drop.train()
 print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.])  

# Eval mode  
 drop.eval()
 print(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

torch.no_grad() 负责关掉梯度计算,节省eval的时间。

只进行inference时,model.eval()是必须使用的,否则会影响结果准确性。 而torch.no_grad()并不是强制的,只影响运行效率。

来源:https://blog.csdn.net/qq_36653505/article/details/84728489

标签:Pytorch,net.train,net.eval
0
投稿

猜你喜欢

  • Python使用循环神经网络解决文本分类问题的方法详解

    2022-12-01 16:49:05
  • Python爬虫之教你利用Scrapy爬取图片

    2022-11-02 10:35:02
  • Python数据可视化库seaborn的使用总结

    2022-08-07 11:43:04
  • python3格式化字符串 f-string的高级用法(推荐)

    2023-04-13 00:56:55
  • asp函数解决SQL注入漏洞

    2008-10-12 19:53:00
  • 如何在小空间放置大图片

    2009-08-04 13:04:00
  • Python容器使用的5个技巧和2个误区总结

    2023-04-09 04:37:01
  • 简介Python的collections模块中defaultdict类型的用法

    2021-01-04 20:14:54
  • 基于PHP读取csv文件内容的详解

    2023-11-16 04:17:48
  • 初学js者对javascript面向对象的认识分析

    2011-03-16 11:04:00
  • 解决python路径错误,运行.py文件,找不到路径的问题

    2023-03-13 05:47:33
  • 网站中视觉元素的设计

    2008-04-27 20:47:00
  • Python素数检测的方法

    2021-02-13 13:07:30
  • 如何实现SQL Server 2005快速Web分页

    2009-01-21 14:51:00
  • 基于PyQt5制作一个群发邮件工具

    2022-09-04 01:46:46
  • 带你轻松接触MySQL数据库的出错代码列表

    2008-12-31 15:06:00
  • IA学习笔记04:标签系统

    2009-09-22 14:40:00
  • Django2.1.3 中间件使用详解

    2023-11-06 19:46:00
  • ASP.NET 2.0防止同一用户同时登录

    2007-10-03 14:30:00
  • PHP实现逐行删除文件右侧空格的方法 <font color=red>原创</font>

    2023-11-22 05:11:25
  • asp之家 网络编程 m.aspxhome.com