PyTorch 如何检查模型梯度是否可导

作者:烟雨风渡 时间:2021-01-21 14:38:31 

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导


from torch.autograd import gradcheck
import torch
import torch.nn as nn

inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导


from torch.autograd import gradcheck
import torch
import torch.nn as nn
# 定义神经网络模型
class Net(nn.Module):

def __init__(self):
       super(Net, self).__init__()
       self.net = nn.Sequential(
           nn.Linear(15, 30),
           nn.ReLU(),
           nn.Linear(30, 15),
           nn.ReLU(),
           nn.Linear(15, 1),
           nn.Sigmoid()
       )

def forward(self, x):
       y = self.net(x)
       return y

net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:


>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
   [ 4.,  6.,  8.],
   [ 6.,  9., 12.],
   [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
   tensor([[ 2.,  3.,  4.],
   [ 4.,  6.,  8.],
   [ 6.,  9., 12.],
   [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导 PyTorch 如何检查模型梯度是否可导

验证:


>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
   [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
   [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

来源:https://blog.csdn.net/tszupup/article/details/112916388

标签:PyTorch,检查,梯度
0
投稿

猜你喜欢

  • ASP与MySQL的连接[图文教程]

    2010-03-14 11:21:00
  • 浅谈python量化 双均线策略(金叉死叉)

    2022-05-28 02:21:58
  • Python全栈之线程详解

    2021-05-21 17:44:21
  • 从错误中学习改正Go语言五个坏习惯提高编程技巧

    2023-10-12 20:06:33
  • JavaScript面试必考之实现手写Promise

    2024-04-16 10:38:49
  • Python读取Excel表格,并同时画折线图和柱状图的方法

    2023-12-25 07:11:27
  • python 调用c语言函数的方法

    2023-12-11 17:24:21
  • js调用设备摄像头的方法

    2024-04-17 09:46:46
  • 利用python对Excel中的特定数据提取并写入新表的方法

    2023-09-17 16:03:10
  • 利用MySqlBulkLoader实现批量插入数据的示例详解

    2024-01-24 08:46:00
  • 原生javascript AJAX 三级联动的实现代码

    2024-04-18 10:00:46
  • Python学习之时间包使用教程详解

    2022-07-18 11:26:39
  • asp 实现当有新信息时播放语音提示的效果

    2011-03-31 11:00:00
  • js数组去重的11种方法

    2024-04-17 10:30:54
  • 使用Cython中prange函数实现for循环的并行

    2023-04-13 05:31:55
  • 安装了Office2003补丁之后,access不能用,打不开了

    2011-05-12 12:19:00
  • 如何使用VUE+faceApi.js实现摄像头拍摄人脸识别

    2023-07-02 16:32:04
  • 微信小程序实现给嵌套template模板传递数据的方式总结

    2024-05-22 10:31:50
  • MySQL的一些安全注意点

    2008-12-24 16:29:00
  • python flask 如何修改默认端口号的方法步骤

    2021-07-04 16:35:14
  • asp之家 网络编程 m.aspxhome.com