详解model.train()和model.eval()两种模式的原理与用法

作者:想变厉害的大白菜 时间:2021-03-20 08:46:56 

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout。

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

来源:https://blog.csdn.net/weixin_44211968/article/details/123774649

标签:model.train(),model.eval(),原理
0
投稿

猜你喜欢

  • pyCaret效率倍增开源低代码的python机器学习工具

    2021-01-09 10:30:26
  • Python学习之pip包管理工具的使用

    2023-07-24 11:01:57
  • 基于MySql的扩展功能生成全局ID

    2024-01-13 07:52:58
  • SQL Server 2016 TempDb里的显著提升

    2024-01-24 17:15:47
  • MySQL性能优化神器Explain的基本使用分析

    2024-01-19 21:56:15
  • Flask中基于Token的身份认证的实现

    2022-11-20 06:45:53
  • MySQL常用分库分表方案汇总

    2024-01-18 10:51:14
  • 如何安装2019Pycharm最新版本(详细教程)

    2022-09-19 12:20:54
  • php实现mysql备份恢复分卷处理的方法

    2023-11-16 20:55:33
  • ES6中let 和 const 的新特性

    2024-05-28 15:41:41
  • Python简单实现安全开关文件的两种方式

    2022-09-15 01:54:38
  • python使用matplotlib绘制柱状图教程

    2021-01-29 20:20:00
  • python 中的列表解析和生成表达式

    2022-01-30 16:14:15
  • 解决Windows 7下安装Oracle 11g相关问题的方法

    2024-01-22 05:19:19
  • Python数据分析中Groupby用法之通过字典或Series进行分组的实例

    2023-03-08 12:56:01
  • 详解Git 的 rebase 命令使用方法

    2023-04-16 16:57:08
  • 页面新开窗口的一点补充

    2008-09-10 12:57:00
  • 浅谈python多线程和多线程变量共享问题介绍

    2022-08-29 04:34:18
  • Python如何快速生成本项目的requeirments.txt实现

    2023-07-21 17:02:04
  • python之OpenCV的作用以及安装案例教程

    2021-11-27 07:14:20
  • asp之家 网络编程 m.aspxhome.com