pytorch中可视化之hook钩子

作者:阿瓦达啃大瓜~ 时间:2021-07-30 04:26:23 

一、hook

在PyTorch中,提供了一个专用的接口使得网络在前向传播过程中能够获取到特征图,这个接口的名称非常形象,叫做hook。
可以想象这样的场景,数据通过网络向前传播,网络某一层我们预先设置了一个钩子,数据传播过后钩子上会留下数据在这一层的样子,读取钩子的信息就是这一层的特征图。
具体实现如下:

1.1 什么是hook,什么情况下使用?

首先,明确一下,为什么需要用hook,假设有这么一个函数

pytorch中可视化之hook钩子

需要通过梯度下降法求最小值,其实现方法如下:

import torch
x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
z.backward()
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",y.requires_grad,y.grad)
print("z.grad:",z.requires_grad,z.grad)

结果如下:

x.grad: True tensor(0.)
y.grad: True None
z.grad: True None

注意:在使用训练PyTorch训练模型时,只有叶节点(即直接指定数值的变量,而不是由其他变量计算得到的,比如网络输入)的梯度会保留,其余中间节点梯度在反向传播完成后就会自动释放以节省显存。 因此y.requires_grad的返回值为True,y.grad却为None。

可以看到上面的requires_grad方法都显示True,但是grad没有返回值。当然pytorch也提供某种方法保留非叶子节点的梯度信息。
使用 retain_grad() 方法可以保留非叶子节点的梯度,使用 retain_grad 保留的grad会占用显存,具体操作如下:

x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
y.retain_grad()
z.retain_grad()
z.backward()
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",y.requires_grad,y.grad)
print("z.grad:",z.requires_grad,z.grad)

out:

x.grad: True tensor(0.)
y.grad: True tensor(-4.)
z.grad: True tensor(1.)

** 重申一次** 使用retain_grad方法会占用显存,如果不想要占用显存,就使用到了hook方法。

对于中间节点的变量a,可以使用a.register_hook(hook_fn)对其grad进行操作。 而hook_fn是一个自定义的函数,其声明为hook_fn(grad) -> Tensor or None

1.2 hook在变量中的使用

1.2.1 hook的打印功能

# 自定义hook方法,其传入参数为grad,打印出使用钩子的节点梯度
def hook_fn(grad):
   print(grad)

x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
y.register_hook(hook_fn)
z.register_hook(hook_fn)
print("backward前")

z.backward()
print("backward后\n")
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",y.requires_grad,y.grad)
print("z.grad:",z.requires_grad,z.grad)

out:

backward前
tensor(1.)
tensor(-4.)
backward后

x.grad: True tensor(0.)
y.grad: True None
z.grad: True None

可以看到绑定hook后,backward打印的时候打印了y和z的梯度,调用grad的时候没有保留grad值,已经释放掉内存。注意,打印出来的结果是反向传播,所以先打印z的梯度,再打印y的梯度。

1.2.2 使用hook改变grad的功能

对标记的节点,梯度加2

def hook_fn(grad):
   grad += 2
   print(grad)
   return grad

x = torch.tensor(3.0, requires_grad=True)
y = (x-2)
z = ((y-x) ** 2)
y.register_hook(hook_fn)
z.register_hook(hook_fn)
print("backward前")

z.backward()
print("backward后\n")
print("x.grad:",x.requires_grad,x.grad)
print("y.grad:",x.requires_grad,y.grad)
print("z.grad:",x.requires_grad,z.grad)

out:

backward前
tensor(3.)
tensor(-10.)
backward后

x.grad: True tensor(2.)
y.grad: True None
z.grad: True None

可以看到梯度教上面的已经发生的改变。

1.3 hook在模型中的使用:

PyTorch中使用register_forward_hook和register_backward_hook获取Module输入和输出的feature_map和grad。使用结构如下: hook_fn(module, input, output) -> Tensor or None
模型中使用hook一点要带有这三个参数module, grad_input, grad_output

1.3.1 register_forward_hook的使用

import torch.nn as nn

def hook_forward_fn(model,put,out):
   print("model:",model)
   print("input:",put)
   print("output:",out)

# 定义一个model
class Net(nn.Module):
   def __init__(self):
       super(Net,self).__init__()
       self.conv = nn.Conv2d(3, 1, 1)
       self.bn = nn.BatchNorm2d(1)
       #self.conv.register_forward_hook(hook_forward_fn)
       #self.bn.register_forward_hook(hook_forward_fn)

def forward(self, x):
       x = self.conv(x)
       x = self.bn(x)
       return torch.relu(x)

net = Net()
# 对模型中的具体某一层使用hook
net.conv.register_forward_hook(hook_forward_fn)
net.bn.register_forward_hook(hook_forward_fn)

x = torch.rand(1, 3, 2, 2, requires_grad=True)
y = net(x).mean()

注意:该方法不需要使用。backword就能输出结果,是记录前向传播的钩子。
结果如下:

model: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1))
input: (tensor([[[[0.4570, 0.6791],
         [0.0197, 0.5040]],

[[0.8883, 0.1808],
         [0.6289, 0.9386]],

[[0.8772, 0.5290],
         [0.0014, 0.3728]]]], requires_grad=True),)
output: tensor([[[[-0.4909, -0.1122],
         [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>)
model: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
input: (tensor([[[[-0.4909, -0.1122],
         [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>),)
output: tensor([[[[-0.2060,  1.6790],
         [-0.8987, -0.5743]]]], grad_fn=<NativeBatchNormBackward0>)

1.3.2 register_backward_hook的使用

使用上面相同的Net模型

def hook_backward_fn(module, grad_input, grad_output):
   print(f"module: {module}")
   print(f"grad_output: {grad_output}")
   print(f"grad_input: {grad_input}")
   print("*"*20)

net = Net()
net.conv.register_backward_hook(hook_backward_fn)
net.bn.register_backward_hook(hook_backward_fn)
x = x = torch.rand(1, 3, 2, 2, requires_grad=True)
y = net(x).mean()
y.backward()

out:

module: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
grad_output: (tensor([[[[0.2500, 0.2500],
         [0.0000, 0.0000]]]]),)
grad_input: (tensor([[[[ 0.6586, -0.3360],
         [-0.3009, -0.0218]]]]), tensor([0.4575]), tensor([0.5000]))
********************
module: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1))
grad_output: (tensor([[[[ 0.6586, -0.3360],
         [-0.3009, -0.0218]]]]),)
grad_input: (tensor([[[[-0.2974,  0.1517],
         [ 0.1359,  0.0098]],

[[ 0.0270, -0.0138],
         [-0.0123, -0.0009]],

[[ 0.2918, -0.1489],
         [-0.1333, -0.0096]]]]), tensor([[[[0.4331]],

[[0.1386]],

[[0.4292]]]]), tensor([-1.4156e-07]))
********************

其结果是逆向输出各节点层的梯度信息。

1.3.3 hook中使用展示卷积层

随便画一张图,图片张这个样子:

pytorch中可视化之hook钩子

使用读取图片发现是个4通道的图像,我们转成单通道并可视化:

import matplotlib.pyplot as plt
import matplotlib.image as mping
img=mping.imread("./test1.png")
print(img.shape)
img = torch.tensor(img[:,:,0]).view(1,1,228,226)
plt.imshow(img[0][0])

pytorch中可视化之hook钩子

接下来创建一个只有卷积层的模型

class Net(nn.Module):
   def __init__(self):
       super(Net,self).__init__()
       self.conv = nn.Sequential(nn.Conv2d(1,1,7),
                                 nn.ReLU()
                                )

def forward(self, x):
       x=self.conv(x)
       return x

使用我们的钩子hook对卷积层的输出进行可视化

def hook_forward_fn(model,put,out):
   print("inputshape:",put[0].shape) # 打印出输入图片的维度
   print("outputshape:",out[0][0].shape) # 经过卷积之后的维度
   # 可视化,因为卷积之后带有grad梯度信息,所以需要使用detach().numpy()方法,否则会报错
   plt.imshow(out[0][0].detach().numpy())

具体完整实现以及可视化代码如下:

import matplotlib.pyplot as plt
import matplotlib.image as mping
import numpy as np

img=mping.imread("./test1.png")
img = torch.tensor(img[:,:,0]).view(1,1,228,226)

def hook_forward_fn(model,put,out):
   print("inputshape:",put[0].shape)
   print("outputshape:",out[0][0].shape)
   plt.imshow(out[0][0].detach().numpy())

class Net(nn.Module):
   def __init__(self):
       super(Net,self).__init__()
       self.conv = nn.Sequential(nn.Conv2d(1,1,7),
                                 nn.ReLU()
                                )

def forward(self, x):
       x=self.conv(x)
       return x

model = Net()
model.conv.register_forward_hook(hook_forward_fn)
y=model(img)

pytorch中可视化之hook钩子

来源:https://blog.csdn.net/weixin_41555165/article/details/127454644

标签:pytorch,hook,钩子
0
投稿

猜你喜欢

  • 浅析SQL Server授予了CREATE TABLE权限但是无法创建表

    2024-01-28 18:26:23
  • js验证表单(form)中的单选(radio)值

    2008-03-18 13:23:00
  • Mysql InnoDB多版本并发控制MVCC详解

    2024-01-23 16:46:25
  • Python PyQt5 Pycharm 环境搭建及配置详解(图文教程)

    2023-06-23 12:44:34
  • 微信小程序实现翻牌小功能

    2023-07-02 05:18:37
  • select下拉菜单实现二级联动效果

    2023-05-22 22:30:32
  • ChatGPT 帮我自动编写 Python 爬虫脚本的详细过程

    2021-09-09 09:13:50
  • 请谨慎对待程序的图标和名称

    2011-06-16 20:35:22
  • pymysql 插入数据 转义处理方式

    2024-01-23 08:43:29
  • Pygame游戏开发之太空射击实战图像精灵下篇

    2022-10-20 16:09:06
  • GoLang切片并发安全解决方案详解

    2024-05-09 09:54:15
  • 如何把Mysql卸载干净(亲测有效)

    2024-01-16 09:06:06
  • JS实现倒计时图文效果

    2024-04-28 09:48:28
  • Python实现识别图片内容的方法分析

    2022-01-04 21:10:51
  • 一篇文章带你了解清楚Mysql 锁

    2024-01-24 21:17:43
  • Python I/O与进程的详细讲解

    2022-11-27 14:07:15
  • Python实现按中文排序的方法示例

    2023-11-29 15:19:22
  • python udp如何实现同时收发信息

    2023-12-16 10:06:33
  • mysql 8.0.22压缩包完整安装与配置教程图解(亲测安装有效)

    2024-01-25 05:26:14
  • 解决MySQL报错:You can‘t specify target table ‘region‘ for update in FROM clause

    2024-01-27 08:02:39
  • asp之家 网络编程 m.aspxhome.com