pytorch 实现在测试的时候启用dropout

作者:qian99 时间:2022-03-16 18:08:22 

我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?

在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。

想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:


def apply_dropout(m):
   if type(m) == nn.Dropout:
       m.train()

下面是完整demo代码:


# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
   def __init__(self):
       super(SimpleNet, self).__init__()
       self.fc = nn.Linear(8, 8)
       self.dropout = nn.Dropout(0.5)
   def forward(self, x):
       x = self.fc(x)
       x = self.dropout(x)
       return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
   if type(m) == nn.Dropout:
       m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)

运行结果:

pytorch 实现在测试的时候启用dropout

可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。

补充:Pytorch之dropout避免过拟合测试

一.做数据

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

二.搭建神经网络

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

三.训练

pytorch 实现在测试的时候启用dropout

四.对比测试结果

注意:测试过程中,一定要注意模式切换

pytorch 实现在测试的时候启用dropout

pytorch 实现在测试的时候启用dropout

来源:https://blog.csdn.net/qian99/article/details/89052262

标签:pytorch,测试,启用,dropout
0
投稿

猜你喜欢

  • Firefox 3.5 新增加的支持(整理)

    2009-08-01 12:51:00
  • asp.net微信开发(永久素材管理)

    2023-07-21 13:02:45
  • python 实现创建文件夹和创建日志文件的方法

    2023-07-07 11:35:10
  • 详解Python中的多线程编程

    2023-09-17 00:34:08
  • 判断数据库表是否存在以及修改表名的方法

    2024-01-22 09:21:24
  • Windows11下MySQL 8.0.29 安装配置方法图文教程

    2024-01-24 09:20:40
  • Python 中pandas索引切片读取数据缺失数据处理问题

    2021-06-02 05:13:28
  • 浅谈django2.0 ForeignKey参数的变化

    2022-03-26 10:11:30
  • 用sleep间隔进行python反爬虫的实例讲解

    2023-02-10 07:00:42
  • 用Python编写一个简单的CS架构后门的方法

    2021-08-07 00:15:58
  • SQLServer导出sql文件/表架构和数据操作步骤

    2024-01-26 19:21:26
  • 深入理解python中函数传递参数是值传递还是引用传递

    2022-02-21 10:08:33
  • 运行python脚本更改Windows背景

    2022-06-11 05:36:54
  • flask应用部署到服务器的方法

    2023-11-25 16:59:38
  • laravel容器延迟加载以及auth扩展详解

    2024-06-05 09:45:06
  • NaviCat连接时提示"不支持远程连接的MySql数据库"解决方法

    2024-01-24 17:03:54
  • vue 使用localstorage实现面包屑的操作

    2024-05-10 14:19:40
  • python 求两个向量的顺时针夹角操作

    2021-05-26 13:31:10
  • Python如何判断字符串是否仅包含数字

    2023-12-23 08:16:04
  • Django REST为文件属性输出完整URL的方法

    2023-07-29 02:42:42
  • asp之家 网络编程 m.aspxhome.com