浅谈PyTorch的可重复性问题(如何使实验结果可复现)

作者:hyk_1996 时间:2021-07-16 06:34:33 

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:


from torch.backends import cudnn
cudnn.benchmark = False      # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch


torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed)    # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)  # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。


import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:


GLOBAL_SEED = 1

def set_seed(seed):
 random.seed(seed)
 np.random.seed(seed)
 torch.manual_seed(seed)
 torch.cuda.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)

GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
 global GLOBAL_WORKER_ID
 GLOBAL_WORKER_ID = worker_id
 set_seed(GLOBAL_SEED + worker_id)

dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

来源:https://blog.csdn.net/hyk_1996/article/details/84307108

标签:PyTorch,重复性,结果,复现
0
投稿

猜你喜欢

  • 说说CSS的优先权 考虑CSS的继承与层叠

    2008-12-11 13:33:00
  • sqlserver 不重复的随机数

    2024-01-14 00:13:59
  • PHP levenshtein()函数用法讲解

    2023-06-01 15:20:29
  • Python标准库uuid模块(生成唯一标识)详解

    2023-07-04 14:03:05
  • MySQL存储引擎简介及MyISAM和InnoDB的区别

    2024-01-26 23:53:17
  • Go语言atomic.Value如何不加锁保证数据线程安全?

    2024-04-25 13:16:52
  • Python实现自动签到脚本功能

    2022-07-24 21:53:40
  • 浅谈python锁与死锁问题

    2022-06-02 16:38:37
  • jquery精度计算代码 jquery指定精确小数位

    2024-05-21 10:20:21
  • Keras多线程机制与flask多线程冲突的解决方案

    2023-09-12 02:10:51
  • 100 个 Python 小例子(练习题四)

    2022-02-15 16:20:05
  • DenseNet121模型实现26个英文字母识别任务

    2023-08-22 13:15:22
  • 图文详解go语言反射实现原理

    2024-02-08 05:01:31
  • 无闪烁更新网页内容JS实现

    2024-05-09 10:37:18
  • JavaScript控制台的更多功能

    2024-02-24 12:46:42
  • python交易记录整合交易类详解

    2022-09-15 20:18:37
  • Python使用MapReduce编程模型统计销量

    2021-07-16 14:24:43
  • python使用Tkinter显示网络图片的方法

    2021-09-26 18:25:38
  • Python字典fromkeys()方法使用代码实例

    2021-07-09 09:54:38
  • layerUI下的绑定事件实例代码

    2024-04-16 09:38:08
  • asp之家 网络编程 m.aspxhome.com