详解model.train()和model.eval()两种模式的原理与用法

作者:想变厉害的大白菜 时间:2021-03-20 08:46:56 

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout。

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

来源:https://blog.csdn.net/weixin_44211968/article/details/123774649

标签:model.train(),model.eval(),原理
0
投稿

猜你喜欢

  • python多线程http压力测试脚本

    2022-12-31 16:48:37
  • OpenCV3.0+Python3.6实现特定颜色的物体追踪

    2021-05-13 09:01:03
  • asp+jsp+JavaScript动态实现添加数据行

    2023-07-03 05:37:15
  • python网络编程:socketserver的基本使用方法实例分析

    2023-11-26 21:33:50
  • css中如何使div居中(垂直水平居中)

    2007-08-13 08:17:00
  • Python Tkinter基础控件用法

    2023-04-11 18:35:14
  • 详谈python3 numpy-loadtxt的编码问题

    2021-08-28 06:42:09
  • Python集合的增删改查操作

    2023-09-30 00:48:18
  • Python实现带图形界面的炸金花游戏(升级版)

    2023-06-27 08:35:20
  • 利用一个简单的例子窥探CPython内核的运行机制

    2023-08-11 04:54:31
  • python pandas分组聚合详细

    2022-01-27 22:21:44
  • 解决springboot yml配置 logging.level 报错问题

    2021-09-21 21:38:02
  • 使用Python批量修改文件名的代码实例

    2022-03-21 04:02:53
  • 一个输入框提示列表效果

    2008-03-09 18:53:00
  • JavaScript 日期联动选择器

    2010-08-01 10:18:00
  • Python实现绘制3D地球旋转效果

    2021-04-17 22:25:37
  • 嵌入Flash应该考虑不支持Flash的浏览器

    2007-12-20 12:29:00
  • ASP缓存类 【先锋缓存类】Ver2004

    2009-01-05 12:28:00
  • python文件目录操作之os模块

    2023-01-10 14:22:59
  • 网站注册那些事儿

    2010-01-05 16:49:00
  • asp之家 网络编程 m.aspxhome.com