PyTorch 如何检查模型梯度是否可导
作者:烟雨风渡 时间:2021-01-21 14:38:31
一、PyTorch 检查模型梯度是否可导
当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。
首先看一下官方文档中关于该函数的介绍:
可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:
Tensor需要是双精度浮点型且设置requires_grad = True
第一个例子:检查某一操作是否可导
from torch.autograd import gradcheck
import torch
import torch.nn as nn
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)
输出为:
Are the gradients correct: True
第二个例子:检查某一网络模型是否可导
from torch.autograd import gradcheck
import torch
import torch.nn as nn
# 定义神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(
nn.Linear(15, 30),
nn.ReLU(),
nn.Linear(30, 15),
nn.ReLU(),
nn.Linear(15, 1),
nn.Sigmoid()
)
def forward(self, x):
y = self.net(x)
return y
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)
输出为:
Are the gradients correct: True
二、Pytorch求导
1.标量对矩阵求导
验证:
>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]]) # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True) #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b) # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad #df/dX = a.dot(b^T)
tensor([[ 2., 3., 4.],
[ 4., 6., 8.],
[ 6., 9., 12.],
[ 8., 12., 16.]])
>>>a.grad b.grad # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1)) # a.dot(b^T)
tensor([[ 2., 3., 4.],
[ 4., 6., 8.],
[ 6., 9., 12.],
[ 8., 12., 16.]])
2.矩阵对矩阵求导
验证:
>>>A = torch.tensor([[1,2],[3,4.]]) #2*2矩阵
>>>X = torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True) # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
[19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
[6., 6., 6.]])
注意:
requires_grad为True的数组必须是float类型
进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)
来源:https://blog.csdn.net/tszupup/article/details/112916388
标签:PyTorch,检查,梯度
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
用Javascript正则表达式验证Email地址
2009-12-09 15:56:00
轻松接触SQL Server 2000实例的命名规则
2009-01-23 13:44:00
在 Python 中如何将天数添加到日期
2023-02-09 03:34:30
MySQL数据库性能优化妙招
2009-03-20 13:13:00
Python subprocess模块功能与常见用法实例详解
2021-08-30 02:46:43
![](https://img.aspxhome.com/file/2023/5/98865_0s.png)
ORACLE常见错误代码的分析与解决(一)
2010-08-02 13:20:00
Oracle 数据库操作技巧集
2010-07-26 12:49:00
SQL Server各种日期计算方法
2008-09-11 21:47:00
在python中利用try..except来代替if..else的用法
2023-09-12 17:50:10
Opencv+Python实现图像运动模糊和高斯模糊的示例
2022-08-06 12:25:19
![](https://img.aspxhome.com/file/2023/9/66349_0s.jpg)
PHP扩展Swoole实现实时异步任务队列示例
2023-11-10 05:11:22
python实现邮件发送功能
2023-10-11 02:27:09
iframe高度自适应,兼容IE,FF
2008-06-18 12:15:00
详解python使用canvas实现移动并绑定键盘
2022-08-18 01:02:52
![](https://img.aspxhome.com/file/2023/8/68368_0s.png)
python复制文件到指定目录的实例
2021-03-17 17:10:26
Mysql 数据库双机热备的配置方法
2010-06-09 19:13:00
Python爬虫基础之简单说一下scrapy的框架结构
2022-01-04 23:19:00
![](https://img.aspxhome.com/file/2023/8/87598_0s.png)
Python 操作mysql数据库查询之fetchone(), fetchmany(), fetchall()用法示例
2023-07-09 00:11:24
js 将json字符串转换为json对象的方法解析
2023-07-22 21:41:49
php flv视频时间获取函数
2023-09-04 13:41:48