pytorch DataLoader的num_workers参数与设置大小详解

作者:龙南希 时间:2022-12-22 12:15:58 

Q:在给Dataloader设置worker数量(num_worker)时,到底设置多少合适?这个worker到底怎么工作的?


   train_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=4)

参数详解:

1、每次dataloader加载数据时:dataloader一次性创建num_worker个worker,(也可以说dataloader一次性创建num_worker个工作进程,worker也是普通的工作进程),并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。

然后,dataloader从RAM中找本轮迭代要用的batch,如果找到了,就使用。如果没找到,就要num_worker个worker继续加载batch到内存,直到dataloader在RAM中找到目标batch。一般情况下都是能找到的,因为batch_sampler指定batch时当然优先指定本轮要用的batch。

2、num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮...迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些。

3、如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度更慢。

设置大小建议:

1、Dataloader的num_worker设置多少才合适,这个问题是很难有一个推荐的值。有以下几个建议:

2、num_workers=0表示只有主进程去加载batch数据,这个可能会是一个瓶颈。

3、num_workers = 1表示只有一个worker进程用来加载batch数据,而主进程是不参与数据加载的。这样速度也会很慢。

num_workers>0 表示只有指定数量的worker进程去加载数据,主进程不参与。增加num_works也同时会增加cpu内存的消耗。所以num_workers的值依赖于 batch size和机器性能。

4、一般开始是将num_workers设置为等于计算机上的CPU数量

5、最好的办法是缓慢增加num_workers,直到训练速度不再提高,就停止增加num_workers的值。

补充:pytorch中Dataloader()中的num_workers设置问题

如果num_workers的值大于0,要在运行的部分放进__main__()函数里,才不会有错:


import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional
import matplotlib.pyplot as plt
import torch.utils.data as Data

BATCH_SIZE=5

x=torch.linspace(1,10,10)
y=torch.linspace(10,1,10)
torch_dataset=Data.TensorDataset(x,y)
loader=Data.DataLoader(
   dataset=torch_dataset,
   batch_size=BATCH_SIZE,
   shuffle=True,
   num_workers=2,
)

def main():
   for epoch in range(3):
       for step,(batch_x,batch_y) in enumerate(loader):
           # training....
           print('Epoch:',epoch,'| step:',step,'| batch x:',batch_x.numpy(),
                 '| batch y:',batch_y.numpy())

if __name__=="__main__":
   main()

'''
# 下面这样直接运行会报错:
for epoch in range(3):
    for step,(batch_x,batch_y) in enumerate(loader):
        # training....
         print('Epoch:',epoch,'| step:',step,'| batch x:',batch_x.numpy(),
                 '| batch y:',batch_y.numpy()
'''

来源:https://blog.csdn.net/qq_28057379/article/details/115427052

标签:pytorch,DataLoader,num,workers
0
投稿

猜你喜欢

  • javascript获取来源的URL代码

    2009-02-25 12:36:00
  • python爬取股票最新数据并用excel绘制树状图的示例

    2023-11-23 14:37:24
  • python+tkinter编写电脑桌面放大镜程序实例代码

    2023-08-02 17:10:43
  • sqlserver 巧妙的自关联运用

    2012-07-21 14:55:12
  • DropDownList绑定选择数据报错提示异常解决方案

    2023-07-18 04:36:13
  • js游戏 俄罗斯方块 源代码

    2008-01-24 13:14:00
  • Python入门教程(三十一)Python的Try和Except

    2022-02-26 01:21:14
  • CentOS7中使用shell脚本安装python3.8环境(推荐)

    2022-08-24 17:04:24
  • Python数据分析之使用scikit-learn构建模型

    2023-11-10 23:19:10
  • python处理二进制数据的方法

    2022-09-08 06:20:09
  • python list.sort()根据多个关键字排序的方法实现

    2021-05-22 03:16:09
  • javascript与jsp发送请求到servlet的几种方式实例

    2023-06-15 15:59:30
  • JavaScript 在各个浏览器中执行的耐性

    2009-02-06 15:26:00
  • 培养色感的一些经验分享

    2023-11-10 03:47:03
  • Python 使用 docopt 解析json参数文件过程讲解

    2021-06-30 21:44:00
  • python退出循环的方法

    2022-06-10 07:24:35
  • 一个网页设计师的成长经历

    2008-05-27 12:38:00
  • 详解Python中图像边缘检测算法的实现

    2021-02-08 09:18:27
  • 用 SQL 脚本将 Access 导入 MSSQL 2000/2005 方法

    2008-10-22 13:51:00
  • SQL Server默认1433端口修改方法

    2010-07-22 22:35:00
  • asp之家 网络编程 m.aspxhome.com