pytorch训练时的显存占用递增的问题解决

作者:来包番茄沙司 时间:2021-04-20 07:12:45 

遇到的问题:

在pytorch训练过程中突然out of memory。

解决方法:

1. 测试的时候爆显存有可能是忘记设置no_grad

加入 with torch.no_grad()

model.eval()
with torch.no_grad():
        for idx, (data, target) in enumerate(data_loader):
            if args.gpu != -1:
                data, target = data.to(args.device), target.to(args.device)
            log_probs = net_g(data)
            probs.append(log_probs)
            
            # sum up batch loss
            test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
            # get the index of the max log-probability
            y_pred = log_probs.data.max(1, keepdim=True)[1]
            correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

2. loss.item()

写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss

3. 在代码中添加以下两行:

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

4. del操作后再加上torch.cuda.empty_cache()

单独使用del、torch.cuda.empty_cache()效果都不明显,因为empty_cache()不会释放还被占用的内存。
所以这里使用了del让对应数据成为“没标签”的垃圾,之后这些垃圾所占的空间就会被empty_cache()回收。

"""添加了最后两行,img和segm是图像和标签输入,很明显通过.cuda()已经是被存在在显存里了;
   outputs是模型的输出,模型在显存里当然其输出也在显存里;loss是通过在显存里的segm和
   outputs算出来的,其也在显存里。这4个对象都是一次性的,使用后应及时把其从显存中清除
   (当然如果你显存够大也可以忽略)。"""
 
def train(model, data_loader, batch_size, optimizer):
    model.train()
    total_loss = 0
    accumulated_steps = 32 // batch_size
    optimizer.zero_grad()
    for idx, (img, segm) in enumerate(tqdm(data_loader)):
        img = img.cuda()
        segm = segm.cuda()
        outputs = model(img)
        loss = criterion(outputs, segm)
        (loss/accumulated_steps).backward()
        if (idx + 1 ) % accumulated_steps == 0:
            optimizer.step() 
            optimizer.zero_grad()
        total_loss += loss.item()
        
        # delete caches
        del img, segm, outputs, loss
        torch.cuda.empty_cache()

补充:Pytorch显存不断增长问题的解决思路

思路很简单,就是在代码的运行阶段输出显存占用量,观察在哪一块存在显存剧烈增加或者显存异常变化的情况。
但是在这个过程中要分级确认问题点,也即如果存在三个文件main.py、train.py、model.py。
在此种思路下,应该先在main.py中确定问题点,然后,从main.py中进入到train.py中,再次输出显存占用量,确定问题点在哪。
随后,再从train.py中的问题点,进入到model.py中,再次确认。
如果还有更深层次的调用,可以继续追溯下去。

例如:

main.py

def train(model,epochs,data):
   for e in range(epochs):
       print("1:{}".format(torch.cuda.memory_allocated(0)))
       train_epoch(model,data)
       print("2:{}".format(torch.cuda.memory_allocated(0)))
       eval(model,data)
       print("3:{}".format(torch.cuda.memory_allocated(0)))

若1与2之间显存增加极为剧烈,说明问题出在train_epoch中,进一步进入到train.py中。

train.py

def train_epoch(model,data):
   model.train()
   optim=torch.optimizer()
   for batch_data in data:
       print("1:{}".format(torch.cuda.memory_allocated(0)))
       output=model(batch_data)
       print("2:{}".format(torch.cuda.memory_allocated(0)))
       loss=loss(output,data.target)
       print("3:{}".format(torch.cuda.memory_allocated(0)))
       optim.zero_grad()
       print("4:{}".format(torch.cuda.memory_allocated(0)))
       loss.backward()
       print("5:{}".format(torch.cuda.memory_allocated(0)))
       utils.func(model)
       print("6:{}".format(torch.cuda.memory_allocated(0)))

如果在1,2之间,5,6之间同时出现显存增加异常的情况。此时需要使用控制变量法,例如我们先让5,6之间的代码失效,然后运行,观察是否仍然存在显存 * 。如果没有,说明问题就出在5,6之间下一级的代码中。进入到下一级代码,进行调试:

utils.py

def func(model):
   print("1:{}".format(torch.cuda.memory_allocated(0)))
   a=f1(model)
   print("2:{}".format(torch.cuda.memory_allocated(0)))
   b=f2(a)
   print("3:{}".format(torch.cuda.memory_allocated(0)))
   c=f3(b)
   print("4:{}".format(torch.cuda.memory_allocated(0)))
   d=f4(c)
   print("5:{}".format(torch.cuda.memory_allocated(0)))

此时我们再展示另一种调试思路,先注释第5行之后的代码,观察显存是否存在先训 * ,如果没有,则注释掉第7行之后的,直至确定哪一行的代码出现导致了显存 * 。假设第9行起作用后,代码出现显存 * ,说明问题出在第九行,显存 * 的问题锁定。

参考链接:
http://www.zzvips.com/article/196059.html
https://blog.csdn.net/fish_like_apple/article/details/101448551

来源:https://blog.csdn.net/weixin_45928096/article/details/128691564

标签:pytorch,显存占用,递增
0
投稿

猜你喜欢

  • python的urllib模块显示下载进度示例

    2023-06-13 17:06:31
  • 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    2023-11-25 12:41:38
  • Python实现统计文本中的字符数量

    2021-01-28 00:27:04
  • 利用PyQt5中QLabel组件实现亚克力磨砂效果

    2023-12-13 18:33:04
  • golang时间/时间戳的获取与转换实例代码

    2023-09-02 06:04:43
  • 使用Python的urllib2模块处理url和图片的技巧两则

    2022-02-15 21:26:00
  • 交互设计师心得——核心竞争力

    2010-01-19 13:45:00
  • python unichr函数知识点总结

    2022-02-03 11:48:31
  • Python数据正态性检验实现过程

    2022-07-10 15:46:14
  • WAP设计基础

    2011-01-06 12:13:00
  • 游戏的用户体验营销小札

    2009-08-30 15:13:00
  • asp如何编写sql语句来查询|搜索数据记录

    2008-10-09 12:35:00
  • python中异常捕获方法详解

    2021-10-30 10:06:09
  • 教程:MySQL中多表操作和批处理方法

    2009-07-30 08:20:00
  • 关于ASP循环表格的问题之解答[比较详细]

    2011-04-08 11:14:00
  • 记录密码的asp代码

    2009-11-02 10:50:00
  • 设计师如何更有效拿到结果?

    2008-09-22 20:30:00
  • Python序列的推导式实现代码

    2022-04-24 05:53:46
  • Python实现绘制多角星实例

    2023-08-26 13:42:14
  • Python基于pygame实现的font游戏字体(附源码)

    2021-04-16 05:06:17
  • asp之家 网络编程 m.aspxhome.com