pytorch 可视化feature map的示例代码
作者:牛丸4 时间:2021-10-21 13:35:49
之前做的一些项目中涉及到feature map 可视化的问题,一个层中feature map的数量往往就是当前层out_channels的值,我们可以通过以下代码可视化自己网络中某层的feature map,个人感觉可视化feature map对调参还是很有用的。
不多说了,直接看代码:
import torch
from torch.autograd import Variable
import torch.nn as nn
import pickle
from sys import path
path.append('/residual model path')
import residual_model
from residual_model import Residual_Model
model = Residual_Model()
model.load_state_dict(torch.load('./model.pkl'))
class myNet(nn.Module):
def __init__(self,pretrained_model,layers):
super(myNet,self).__init__()
self.net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
self.net2 = nn.Sequential(*list(pretrained_model.children())[:layers[1]])
self.net3 = nn.Sequential(*list(pretrained_model.children())[:layers[2]])
def forward(self,x):
out1 = self.net1(x)
out2 = self.net(out1)
out3 = self.net(out2)
return out1,out2,out3
def get_features(pretrained_model, x, layers = [3, 4, 9]): ## get_features 其实很简单
'''
1.首先import model
2.将weights load 进model
3.熟悉model的每一层的位置,提前知道要输出feature map的网络层是处于网络的那一层
4.直接将test_x输入网络,*list(model.chidren())是用来提取网络的每一层的结构的。net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]]) ,就是第三层前的所有层。
'''
net1 = nn.Sequential(*list(pretrained_model.children())[:layers[0]])
# print net1
out1 = net1(x)
net2 = nn.Sequential(*list(pretrained_model.children())[layers[0]:layers[1]])
# print net2
out2 = net2(out1)
#net3 = nn.Sequential(*list(pretrained_model.children())[layers[1]:layers[2]])
#out3 = net3(out2)
return out1, out2
with open('test.pickle','rb') as f:
data = pickle.load(f)
x = data['test_mains'][0]
x = Variable(torch.from_numpy(x)).view(1,1,128,1) ## test_x必须为Varibable
#x = Variable(torch.randn(1,1,128,1))
if torch.cuda.is_available():
x = x.cuda() # 如果模型的训练是用cuda加速的话,输入的变量也必须是cuda加速的,两个必须是对应的,网络的参数weight都是用cuda加速的,不然会报错
model = model.cuda()
output1,output2 = get_features(model,x)## model是训练好的model,前面已经import 进来了Residual model
print('output1.shape:',output1.shape)
print('output2.shape:',output2.shape)
#print('output3.shape:',output3.shape)
output_1 = torch.squeeze(output2,dim = 0)
output_1_arr = output_1.data.cpu().numpy() # 得到的cuda加速的输出不能直接转变成numpy格式的,当时根据报错的信息首先将变量转换为cpu的,然后转换为numpy的格式
output_1_arr = output_1_arr.reshape([output_1_arr.shape[0],output_1_arr.shape[1]])
来源:https://blog.csdn.net/baidu_36161077/article/details/81388221
标签:pytorch,可视化,feature,map
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
简单了解python 邮件模块的使用方法
2021-07-11 00:08:14
Python学习之configparser模块的使用详解
2022-07-21 23:21:25
![](https://img.aspxhome.com/file/2023/0/66380_0s.png)
python算法学习之桶排序算法实例(分块排序)
2022-09-08 13:11:33
python贪吃蛇核心功能实现上
2021-12-06 15:49:18
![](https://img.aspxhome.com/file/2023/6/67866_0s.png)
Python Tensor FLow简单使用方法实例详解
2022-01-01 16:55:44
![](https://img.aspxhome.com/file/2023/2/109562_0s.png)
P3P 和 跨域 (cross-domain) cookie 访问(读取和设置)
2011-04-02 10:42:00
ASP获取刚插入记录的自动编号ID
2008-11-17 20:41:00
Python调整matplotlib图片大小的3种方法汇总
2023-11-28 13:20:26
![](https://img.aspxhome.com/file/2023/9/92429_0s.png)
解决jupyter notebook import error但是命令提示符import正常的问题
2022-08-19 22:10:30
Django-Scrapy生成后端json接口的方法示例
2021-07-16 18:46:46
![](https://img.aspxhome.com/file/2023/4/97094_0s.png)
使用python检查值是否已经存在于字典列表中
2023-10-25 03:08:21
![](https://img.aspxhome.com/file/2023/1/84951_0s.png)
PJBlog3优化——301定向跳转解决重复内容的问题
2009-05-20 10:40:00
Python 中的Sympy详细使用
2021-10-03 03:22:45
![](https://img.aspxhome.com/file/2023/2/80782_0s.png)
用python做游戏的细节详解
2022-02-08 05:18:39
Python详细讲解图像处理的而两种库OpenCV和Pillow
2022-08-14 05:23:19
![](https://img.aspxhome.com/file/2023/3/79493_0s.png)
SQL Server技巧之快速得到表的记录总数
2011-01-04 14:36:00
![](https://img.aspxhome.com/file/UploadPic/20111/20111420216938s.jpg)
Python 中如何使用 virtualenv 管理虚拟环境
2022-02-20 00:57:44
![](https://img.aspxhome.com/file/2023/4/68274_0s.png)
PHP PDOStatement::rowCount讲解
2023-06-06 12:24:04
给在DreamWeaver编写CSS的人一些习惯建议
2007-12-25 12:10:00
django中模板的html自动转意方法
2023-06-28 15:33:49