使用pytorch提取卷积神经网络的特征图可视化

作者:落樱弥城 时间:2023-02-01 20:32:30 

前言

文章中的代码是参考基于Pytorch的特征图提取编写的代码本身很简单这里只做简单的描述。

1. 效果图

先看效果图(第一张是原图,后面的都是相应的特征图,这里使用的网络是resnet50,需要注意的是下面图片显示的特征图是经过放大后的图,原图是比较小的图,因为太小不利于我们观察):

使用pytorch提取卷积神经网络的特征图可视化

使用pytorch提取卷积神经网络的特征图可视化

使用pytorch提取卷积神经网络的特征图可视化

2. 完整代码

import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
import skimage.data
import skimage.io
import skimage.transform
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
from PIL import Image
import cv2

class FeatureExtractor(nn.Module):
   def __init__(self, submodule, extracted_layers):
       super(FeatureExtractor, self).__init__()
       self.submodule = submodule
       self.extracted_layers = extracted_layers

def forward(self, x):
       outputs = {}
       for name, module in self.submodule._modules.items():
           if "fc" in name:
               x = x.view(x.size(0), -1)

x = module(x)
           print(name)
           if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
               outputs[name] = x

return outputs

def get_picture(pic_name, transform):
   img = skimage.io.imread(pic_name)
   img = skimage.transform.resize(img, (256, 256))
   img = np.asarray(img, dtype=np.float32)
   return transform(img)

def make_dirs(path):
   if os.path.exists(path) is False:
       os.makedirs(path)

def get_feature():
   pic_dir = './images/2.jpg'
   transform = transforms.ToTensor()
   img = get_picture(pic_dir, transform)
   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
   # 插入维度
   img = img.unsqueeze(0)

img = img.to(device)

net = models.resnet101().to(device)
   net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
   exact_list = None
   dst = './feautures'
   therd_size = 256

myexactor = FeatureExtractor(net, exact_list)
   outs = myexactor(img)
   for k, v in outs.items():
       features = v[0]
       iter_range = features.shape[0]
       for i in range(iter_range):
           #plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
           if 'fc' in k:
               continue

feature = features.data.numpy()
           feature_img = feature[i,:,:]
           feature_img = np.asarray(feature_img * 255, dtype=np.uint8)

dst_path = os.path.join(dst, k)

make_dirs(dst_path)
           feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
           if feature_img.shape[0] < therd_size:
               tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
               tmp_img = feature_img.copy()
               tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation =  cv2.INTER_NEAREST)
               cv2.imwrite(tmp_file, tmp_img)

dst_file = os.path.join(dst_path, str(i) + '.png')
           cv2.imwrite(dst_file, feature_img)

if __name__ == '__main__':
   get_feature()

3. 代码说明

下面的模块是根据所指定的模型筛选出指定层的特征图输出,如果未指定也就是extracted_layers是None则以字典的形式输出全部的特征图,另外因为全连接层本身是一维的没必要输出因此进行了过滤。

class FeatureExtractor(nn.Module):
   def __init__(self, submodule, extracted_layers):
       super(FeatureExtractor, self).__init__()
       self.submodule = submodule
       self.extracted_layers = extracted_layers

def forward(self, x):
       outputs = {}
       for name, module in self.submodule._modules.items():
           if "fc" in name:
               x = x.view(x.size(0), -1)

x = module(x)
           print(name)
           if self.extracted_layers is None or name in self.extracted_layers and 'fc' not in name:
               outputs[name] = x

return outputs

这段主要是存储图片,为每个层创建一个文件夹将特征图以JET的colormap进行按顺序存储到该文件夹,并且如果特征图过小也会对特征图放大同时存储原始图和放大后的图。

for k, v in outs.items():
       features = v[0]
       iter_range = features.shape[0]
       for i in range(iter_range):
           #plt.imshow(x[0].data.numpy()[0,i,:,:],cmap='jet')
           if 'fc' in k:
               continue

feature = features.data.numpy()
           feature_img = feature[i,:,:]
           feature_img = np.asarray(feature_img * 255, dtype=np.uint8)

dst_path = os.path.join(dst, k)

make_dirs(dst_path)
           feature_img = cv2.applyColorMap(feature_img, cv2.COLORMAP_JET)
           if feature_img.shape[0] < therd_size:
               tmp_file = os.path.join(dst_path, str(i) + '_' + str(therd_size) + '.png')
               tmp_img = feature_img.copy()
               tmp_img = cv2.resize(tmp_img, (therd_size,therd_size), interpolation =  cv2.INTER_NEAREST)
               cv2.imwrite(tmp_file, tmp_img)

dst_file = os.path.join(dst_path, str(i) + '.png')
           cv2.imwrite(dst_file, feature_img)

这里主要是一些参数,比如要提取的网络,网络的权重,要提取的层,指定的图像放大的大小,存储路径等等。

net = models.resnet101().to(device)
   net.load_state_dict(torch.load('./model/resnet101-5d3b4d8f.pt'))
   exact_list = None#['conv1']
   dst = './feautures'
   therd_size = 256

4. 可视化梯度,feature

上面的办法只是简单的将经过网络计算的图片的输出的feature进行图片,github上有将CNN的梯度等全部进行可视化的代码:pytorch-cnn-visualizations,需要注意的是如果只是简单的替换成自己的网络可能无法运行,大概率会报model没有features或者classifier等错误,这两个是进行分类网络定义时的Sequential,其实就是索引网络的每一层,自己稍微修改用model.children()等方法进行替换即可,我自己修改之后得到的代码grayondream-pytorch-visualization(本来想稍微封装一下成为一个更加通用的结构,暂时没时间以后再说吧!),下面是效果图:

使用pytorch提取卷积神经网络的特征图可视化

使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化使用pytorch提取卷积神经网络的特征图可视化

来源:https://blog.csdn.net/GrayOnDream/article/details/99090247

标签:pytorch,特征图,可视化
0
投稿

猜你喜欢

  • SQL Server索引超出了数组界限的解决方案

    2024-01-12 19:14:41
  • 用Python实现等级划分

    2022-10-11 23:45:21
  • python3连接MySQL8.0的两种方式

    2024-01-20 20:16:14
  • 点选TOP后并不是直接跳到页顶的,而是滚动上去

    2023-09-07 02:36:43
  • 五种提高 SQL 性能的方法

    2008-05-16 10:40:00
  • 如何使用sql语句来修改数据记录

    2007-06-21 11:48:00
  • SQL Server 2000 作数据库服务器的优点

    2009-01-23 13:47:00
  • 实践Python的爬虫框架Scrapy来抓取豆瓣电影TOP250

    2021-04-26 21:27:11
  • ASP调用系统ping命令代码

    2008-04-27 20:45:00
  • 在ASP中使用SQL语句之11:记录统计

    2007-08-11 13:27:00
  • 某年第一周开始日期sql实现方法

    2012-02-25 20:02:30
  • python 获取等间隔的数组实例

    2023-05-21 15:07:16
  • 使用python的chardet库获得文件编码并修改编码

    2022-02-23 18:22:35
  • Python3之乱码\\xe6\\x97\\xa0\\xe6\\xb3\\x95处理方式

    2021-03-30 10:19:47
  • python连接mongodb操作数据示例(mongodb数据库配置类)

    2023-05-01 17:21:17
  • Blender Python编程实现批量导入网格并保存渲染图像

    2021-06-09 04:21:20
  • Mac 安装和卸载 Mysql5.7.11 的方法

    2024-01-23 16:09:58
  • 使用Python第三方库pygame写个贪吃蛇小游戏

    2021-05-19 11:08:37
  • JS实现两周内自动登录功能

    2023-08-04 21:20:57
  • Python代码阅读--列表元素逻辑判断

    2022-08-05 16:12:57
  • asp之家 网络编程 m.aspxhome.com