pytorch 运行一段时间后出现GPU OOM的问题

作者:ASR_THU 时间:2021-05-21 17:01:34 

pytorch的dataloader会将数据传到GPU上,这个过程GPU的mem占用会逐渐增加,为了避免GPUmen被无用的数据占用,可以在每个step后用del删除一些变量,也可以使用torch.cuda.empty_cache()释放显存:


del targets, input_k, input_mask
torch.cuda.empty_cache()

这时能观察到GPU的显存一直在动态变化。

但是上述方式不是一个根本的解决方案,因为他受到峰值的影响很大。比如某个batch的数据量明显大于其他batch,可能模型处理该batch时显存会不够用,这也会导致OOM,虽然其他的batch都能顺利执行。

显存的占用跟这几个因素相关:

模型参数量

batch size

一个batch的数据 size

通常我们不希望改变模型参数量,所以只能通过动态调整batch-size,使得一个batch的数据 size不会导致显存OOM:


ilen = int(sorted_data[start][1]['input'][0]['shape'][0])
olen = int(sorted_data[start][1]['output'][0]['shape'][0])
# if ilen = 1000 and max_length_in = 800
# then b = batchsize / 2
# and max(1, .) avoids batchsize = 0
# 太长的句子会被动态改变bsz,单独成一个batch,否则padding的部分就太多了,数据量太大,OOM
factor = max(int(ilen / max_length_in), int(olen / max_length_out))
b = max(1, int(batch_size / (1 + factor)))
#b = batch_size
end = min(len(sorted_data), start + b)
minibatch.append(sorted_data[start:end])
if end == len(sorted_data):
   break
start = end

此外,如何选择一个合适的batchsize也是个很重要的问题,我们可以先对所有数据按照大小(长短)排好序(降序),不进行shuffle,按照64,32,16依次尝试bsz,如果模型在执行第一个batch的时候没出现OOM,那么以后一定也不会出现OOM(因为降序排列了数据,所以前面的batch的数据size最大)。

还有以下问题

pytorch increasing cuda memory OOM 问题

改了点model 的计算方式,然后就 OOM 了,调小了 batch_size,然后发现发现是模型每次迭代都会动态增长 CUDA MEMORY, 在排除了 python code 中的潜在内存溢出问题之后,基本可以把问题定在 pytorch 的图计算问题上了,说明每次迭代都重新生成了一张计算图,然后都保存着在,就 OOM 了。

参考

CUDA memory continuously increases when net(images) called in every iteration

Understanding graphs and state

说是会生成多个计算图:


loss = SomeLossFunction(out) + SomeLossFunction(out)

准备用 sum来避免多次生成计算图的问题:


loss = Variable(torch.sum(torch.cat([loss1, loss2], 0)))

然而,调着调着就好了,和报错前的 code 没太大差别。估计的原因是在pycharm 远程连接服务器的时候 code 的保存版本差异问题,这个也需要解决一下。

还有个多次迭代再计算梯度的问题,类似于 caffe中的iter_size,这个再仔细看看。

来源:https://blog.csdn.net/zongza/article/details/98647490

标签:pytorch,GPU,OOM
0
投稿

猜你喜欢

  • 发工资啦!教你用Python实现邮箱自动群发工资条

    2023-10-12 19:11:17
  • JavaScript高级程序设计 读书笔记之十 本地对象Date日期

    2024-04-22 22:33:48
  • python爬虫基础教程:requests库(二)代码实例

    2023-05-31 07:56:35
  • Golang编译器介绍

    2024-05-02 16:26:01
  • Scrapy实现模拟登录的示例代码

    2023-07-13 21:53:11
  • Python实现监控程序执行时间并将其写入日志的方法

    2023-01-15 01:35:53
  • react-native ListView下拉刷新上拉加载实现代码

    2023-07-02 06:35:34
  • CSS代码实现下划线样式的输入框效果

    2010-03-16 12:42:00
  • python中的__init__ 、__new__、__call__小结

    2021-07-19 20:10:38
  • ChatGpt无法访问或错误码1020的几种解决方案

    2023-03-03 05:58:36
  • 别人复制你网站的文章时自动加上注释

    2009-02-09 13:20:00
  • Go语言包管理模式示例分析

    2024-05-22 10:20:17
  • 用Python解数独的方法示例

    2021-01-31 18:38:44
  • python实现kmp算法的实例代码

    2022-07-19 15:07:16
  • WEB界面设计五种特征

    2010-03-16 12:34:00
  • Django 实现图片上传和下载功能

    2023-01-14 09:53:21
  • python多线程与多进程及其区别详解

    2021-10-10 04:04:49
  • MySQL 子查询和分组查询

    2024-01-18 22:05:36
  • 一文解答什么是MySQL的回表

    2024-01-18 02:41:56
  • 解析Python中while true的使用

    2022-07-23 21:19:53
  • asp之家 网络编程 m.aspxhome.com