pytorch的Backward过程用时太长问题及解决
作者:Ai_Taoism 时间:2022-12-11 00:16:06
pytorch Backward过程用时太长
问题描述
使用pytorch对网络进行训练的时候遇到一个问题,forward阶段很快(只需要几毫秒),backward阶段却用时很长(需要十多秒)。
导致这个问题的原因很容易被大家忽视,而且网上基本上没有直接的解决方案,经过一天的折腾,总算把导致这个问题的原因搞清楚了。
解决方案
导致这个问题的原因在于训练数据的浅拷贝,由于backward过程中的梯度是和模型推理过程中的张量相关的,如果这些张量在被模型使用之前没有被深拷贝,意味着backward过程的会重复从这些张量的原始内存地址中取值,这个过程非常耗时。所以为了避免这个问题,需要养成一个好习惯,就是将张量数据输入模型之前进行深拷贝
pytorch的深拷贝方式如下:
tensor_a = tensor_b.clone().detach()
Pytorch backward()简单理解
backward()是反向传播求梯度,具体实现过程如下
import torch
x=torch.tensor([1,2,3],requires_grad=True,dtype=torch.double)
y=x**2
z=y.mean()
z.backward()
print(x.grad)
结果
tensor([0.6667, 1.3333, 2.0000], dtype=torch.float64)
有几个重要的点
1.必须要加上requires_grad=True才能求
2. 一般来说,需要标量才能求梯度。
3.具体过程如下:
z是一个标量(1*1矩阵)分别对x1,x2,x3求偏导, 再代入x1,x2,x3的数值,就是如上程序输出的结果
来源:https://blog.csdn.net/ahhhhhh520/article/details/124864850
标签:pytorch,Backward,过程
0
投稿
猜你喜欢
浅谈哪个Python库才最适合做数据可视化
2022-12-05 00:34:58
深度学习TextRNN的tensorflow1.14实现示例
2023-12-31 18:59:23
Python绘制组合图的示例
2023-07-30 01:34:31
详解Python 定时框架 Apscheduler原理及安装过程
2021-06-16 15:15:09
javascript一些不错的函数脚本代码
2023-07-02 05:25:52
python实现进度条的多种实现
2021-03-20 10:39:52
Python实现抓取网页并且解析的实例
2022-01-12 13:24:53
组件:Adodb.Stream 用法介绍
2008-10-09 12:39:00
python3爬虫中多线程的优势总结
2023-05-15 02:41:07
微信小程序实现2048小游戏的详细过程
2024-04-23 09:11:18
mysql数据库无法被其他ip访问的解决方法
2024-01-25 09:04:57
CSS框架的相关汇总(CSS Frameworks)
2008-04-02 12:00:00
thinkPHP框架实现类似java过滤器的简单方法示例
2023-11-22 12:24:47
Golang语言实现gRPC的具体使用
2024-05-05 09:26:19
详解Windows下PyCharm安装Numpy包及无法安装问题解决方案
2021-01-27 11:58:32
解决pycharm最左侧Tool Buttons显示不全的问题
2022-11-22 13:23:22
Python tornado队列示例-一个并发web爬虫代码分享
2022-03-13 12:13:55
php中使用key,value,current,next和prev函数遍历数组的方法
2023-10-18 20:17:39
Python的包管理器pip更换软件源的方法详解
2023-02-03 05:25:22
mysql分表分库的应用场景和设计方式
2024-01-22 05:49:39