A Detailed Guide to Extracting Feature Maps with PyTorch

Author: 拜阳  Date: 2023-09-11 16:16:02

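Loading an official pretrained model

torchvision.models predefines many well-known public model architectures.

If the pretrained argument is False, only the model architecture is built, with randomly initialized weights; if it is True, the pretrained parameters are downloaded first. Note that torchvision 0.13 and later deprecate pretrained in favor of the weights argument, for example weights=models.VGG19_Weights.DEFAULT.

If you only want to run the model and not train it, call model.eval() and model.requires_grad_(False).

To inspect the model's layers and parameters, use modules() and named_modules(). named_modules() yields tuples of length 2: the first element is the name and the second is the module itself.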

# -*- coding: utf-8 -*-
from torch import nn
from torchvision import models

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model.eval()
model.requires_grad_(False)

# get model component
features = model.features
modules = features.modules()
named_modules = features.named_modules()

# print modules
for module in modules:
   if isinstance(module, nn.Conv2d):
       weight = module.weight
       bias = module.bias
       print(module, weight.shape, bias.shape,
             weight.requires_grad, bias.requires_grad)
   elif isinstance(module, nn.ReLU):
       print(module)

print()
for named_module in named_modules:
   name = named_module[0]
   module = named_module[1]
   if isinstance(module, nn.Conv2d):
       weight = module.weight
       bias = module.bias
       print(name, module, weight.shape, bias.shape,
             weight.requires_grad, bias.requires_grad)
   elif isinstance(module, nn.ReLU):
       print(name, module)
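
The behaviour described above can be checked without downloading any weights. The following is a minimal sketch that uses a tiny hand-built nn.Sequential as a stand-in for model.features (the stand-in network and its layer sizes are assumptions for illustration, not part of VGG). It shows that named_modules() yields the container itself first, under an empty name, and that requires_grad_(False) freezes every parameter.

```python
import torch
from torch import nn

# A tiny stand-in for model.features (hypothetical, not a real VGG).
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
)
features.eval()
features.requires_grad_(False)

# named_modules() yields (name, module) pairs; the first pair is the
# container itself, whose name is the empty string.
names = [name for name, _ in features.named_modules()]
print(names)

# After requires_grad_(False), no parameter tracks gradients.
frozen = all(not p.requires_grad for p in features.parameters())
print(frozen)

# torch.no_grad() additionally skips building the autograd graph.
with torch.no_grad():
    out = features(torch.rand(1, 3, 8, 8))
print(out.shape)
```

Wrapping inference in torch.no_grad() is optional here, but it is a common companion to eval() when only feature maps are needed.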

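Image preprocessing

Whether the image is read with OpenCV or PIL, transforms.ToTensor() converts the [H, W, 3] data into a [3, H, W] tensor and scales uint8 values to the range [0, 1]. With OpenCV, remember to convert the channel order from BGR to RGB first.

The VGG family of models expects normalized inputs; torchvision.transforms is the recommended way to do this. According to the PyTorch Hub page, the models expect:

mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

Reference: https://pytorch.org/hub/pytorch_vision_vgg/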

# -*- coding: utf-8 -*-
from PIL import Image
import cv2
import torch
from torchvision import transforms

# transforms for preprocess
preprocess = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# load image using cv2
image_cv2 = cv2.imread('lena_std.bmp')
image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)
image_cv2 = preprocess(image_cv2)

# load image using pil
image_pil = Image.open('lena_std.bmp')
image_pil = preprocess(image_pil)

# check whether image_cv2 and image_pil are same
print(torch.all(image_cv2 == image_pil))
print(image_cv2.shape, image_pil.shape)
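
To make the two preprocessing steps less of a black box, here is a rough sketch of what ToTensor() and Normalize() compute, written with plain NumPy and torch operations. The random array is an assumed stand-in for the output of cv2.imread (BGR, uint8); it is not a real image.

```python
import numpy as np
import torch

# Synthetic BGR image standing in for cv2.imread output: (H, W, 3), uint8.
rng = np.random.default_rng(0)
bgr = rng.integers(0, 256, size=(4, 5, 3), dtype=np.uint8)

# BGR -> RGB by reversing the channel axis (the effect of COLOR_BGR2RGB).
# copy() is needed because torch.from_numpy rejects negative strides.
rgb = bgr[:, :, ::-1].copy()

# ToTensor step: HWC uint8 -> CHW float32 in [0, 1].
tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0

# Normalize step: per-channel (x - mean) / std.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
normalized = (tensor - mean) / std

print(normalized.shape, normalized.dtype)
```

Because the mean and std are listed in RGB order, skipping the BGR to RGB conversion would silently apply the red statistics to the blue channel and vice versa.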

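Extracting a single feature map

If only one layer's feature map is needed, the model can be truncated to save compute and GPU memory.

The + 1 in the slice below is there because Python slices exclude the end index: features[:16 + 1] keeps modules 0 through 16, so index 16 (conv3_4) becomes the last layer. Note that named_modules() yields the whole container first (with an empty name) and the submodules after it, but this does not shift the indices used for slicing.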

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')
print(model)

# load and preprocess image
preprocess = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
   transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward
output = model(inputs)
print(output.shape)
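
The slicing behaviour can be verified on a small model. The sketch below uses a hypothetical five-layer nn.Sequential (layer sizes are arbitrary assumptions) and also picks the device at run time, so it works on machines without a GPU, unlike the hard-coded 'cuda' above.

```python
import torch
from torch import nn

# Hypothetical small network standing in for model.features.
features = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),  # index 0
    nn.ReLU(),                      # index 1
    nn.Conv2d(4, 8, 3, padding=1),  # index 2
    nn.ReLU(),                      # index 3
    nn.MaxPool2d(2),                # index 4
)

last = 2
# Slices exclude the end index, so + 1 keeps layer `last` itself.
truncated = features[:last + 1]
print(len(truncated))

# The slice is a new Sequential that shares the original modules.
print(truncated[last] is features[last])

# Fall back to the CPU when CUDA is not available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
truncated = truncated.eval().to(device)
with torch.no_grad():
    out = truncated(torch.rand(1, 3, 8, 8, device=device))
print(out.shape)
```

Since the modules are shared rather than copied, moving the truncated model to another device also moves those layers of the original model.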

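Extracting multiple feature maps

Method 1: run the model layer by layer and store the feature map whenever a layer of interest is reached.

Method 2: use register_forward_hook. This requires a class that caches the feature map as a member variable.

The two methods run at about the same speed.

Method 1 is simple and intuitive, but it only works for networks such as VGG that have no cross-layer (skip) connections; method 2 is more general.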

# -*- coding: utf-8 -*-
from PIL import Image
import torch
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.vgg19(pretrained=True)
model = model.features[:16 + 1]  # 16 = conv3_4
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
   name = named_module[0]
   module = named_module[1]
   print('-------- %s --------' % name)
   print(module)
   print()

# load and preprocess image
preprocess = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
   transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

# forward - 1
layers = [2, 7, 8, 9, 16]
layers = sorted(set(layers))
feature_maps = {}
feature = inputs
for i in range(max(layers) + 1):
   feature = model[i](feature)
   if i in layers:
       feature_maps[i] = feature
for key in feature_maps:
   print(key, feature_maps.get(key).shape)

# forward - 2
class FeatureHook:
   def __init__(self, module):
       self.inputs = None
       self.output = None
       self.hook = module.register_forward_hook(self.get_features)

   def get_features(self, module, inputs, output):
       self.inputs = inputs
       self.output = output

layer_names = ['2', '7', '8', '9', '16']
hook_modules = []
for named_module in model.named_modules():
   name = named_module[0]
   module = named_module[1]
   if name in layer_names:
       hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
   print(feature.shape)

# check correctness
for i, layer in enumerate(layers):
   feature1 = feature_maps.get(layer)
   feature2 = features[i]
   print(torch.all(feature1 == feature2))
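
One caveat applies to both methods: torchvision's VGG uses nn.ReLU(inplace=True), so a stored reference to a convolution output is overwritten by the ReLU that follows it. Both methods then agree with each other, but what they hold is the post-ReLU tensor. The following sketch, on a hypothetical two-layer model, contrasts storing a reference with storing a copy; the hook names save_reference and save_copy are made up for this example.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1),  # index 0
    nn.ReLU(inplace=True),          # index 1, overwrites its input
).eval()

captured = {}

def save_reference(module, inputs, output):
    # Keeps the very tensor that the in-place ReLU will overwrite.
    captured['ref'] = output

def save_copy(module, inputs, output):
    # Keeps an independent snapshot of the convolution output.
    captured['copy'] = output.detach().clone()

handles = [
    model[0].register_forward_hook(save_reference),
    model[0].register_forward_hook(save_copy),
]
with torch.no_grad():
    model(torch.randn(1, 3, 8, 8))
for handle in handles:
    handle.remove()

# The reference no longer contains negative values; the copy still does.
print(captured['ref'].min().item() >= 0)
print(captured['copy'].min().item() < 0)
```

If pre-activation features are what you need, cloning in the hook (or hooking a model built with inplace=False) is one way to get them, at the cost of extra memory.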

With the second method (register_forward_hook), ResNet feature maps can be obtained just as easily.

The first method cannot be used here, because a ResNet model cannot be indexed as model[i].

# -*- coding: utf-8 -*-
from PIL import Image
from torchvision import models
from torchvision import transforms

# load model. If pretrained is True, there will be a downloading process
model = models.resnet18(pretrained=True)
model.eval()
model.requires_grad_(False)
model.to('cuda')

# check module name
for named_module in model.named_modules():
   name = named_module[0]
   module = named_module[1]
   print('-------- %s --------' % name)
   print(module)
   print()

# load and preprocess image
preprocess = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
   transforms.Resize(size=(224, 224))
])
image = Image.open('lena_std.bmp')
image = preprocess(image)
inputs = image.unsqueeze(0)  # add batch dimension
inputs = inputs.cuda()

class FeatureHook:
   def __init__(self, module):
       self.inputs = None
       self.output = None
       self.hook = module.register_forward_hook(self.get_features)

   def get_features(self, module, inputs, output):
       self.inputs = inputs
       self.output = output

layer_names = [
   'conv1',
   'layer1.0.relu',
   'layer2.0.conv1'
]

hook_modules = []
for named_module in model.named_modules():
   name = named_module[0]
   module = named_module[1]
   if name in layer_names:
       hook_modules.append(module)

hooks = [FeatureHook(module) for module in hook_modules]
output = model(inputs)
features = [hook.output for hook in hooks]
for feature in features:
   print(feature.shape)
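
A variation on the hook approach is to key the captured features by module name and to remove the hooks once they are no longer needed, since register_forward_hook returns a handle with a remove() method. This is only a sketch: TinyResNet is a made-up toy network with a single skip connection, and make_hook is a hypothetical helper.

```python
import torch
from torch import nn

class TinyResNet(nn.Module):
    """Hypothetical toy network with one skip connection."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(4, 4, 3, padding=1)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        return x + self.conv2(x)  # skip connection

model = TinyResNet().eval()
layer_names = ['conv1', 'conv2']
modules = dict(model.named_modules())  # name -> module lookup

features = {}

def make_hook(name):
    # Build a hook that stores the output under the module's name.
    def hook(module, inputs, output):
        features[name] = output.detach()
    return hook

handles = [modules[n].register_forward_hook(make_hook(n)) for n in layer_names]
with torch.no_grad():
    model(torch.rand(1, 3, 8, 8))

# Remove the hooks so later forward passes no longer record anything.
for handle in handles:
    handle.remove()

for name, feature in features.items():
    print(name, tuple(feature.shape))
```

Keying by name avoids relying on the order in which named_modules() happens to visit the layers.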

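So how can a network like ResNet be truncated?

Print the model's children to decide where to cut, then reassemble the kept part with nn.Sequential, as in the commands below.

Note that the module names change after reassembly.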

import torch
print(list(model.children()))
model = torch.nn.Sequential(*list(model.children())[:6])
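
The renaming can be seen on a small example. The sketch below uses a hypothetical TinyNet (not a real ResNet) and shows the child names before and after reassembly. Keep in mind that nn.Sequential only chains the child modules: any functional step inside the original forward(), such as the flatten here, is not carried over, so the cut should be placed before such steps.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical toy network with named children."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        return self.fc(torch.flatten(x, 1))

model = TinyNet().eval()
original_names = [name for name, _ in model.named_children()]
print(original_names)

# Keep the first two children; Sequential renames them '0', '1', ...
truncated = nn.Sequential(*list(model.children())[:2])
new_names = [name for name, _ in truncated.named_children()]
print(new_names)

with torch.no_grad():
    out = truncated(torch.rand(1, 3, 8, 8))
print(out.shape)
```

Any layer_names list written for the original model therefore has to be rewritten against the new names before hooks are registered on the truncated model.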

Source: https://blog.csdn.net/bby1987/article/details/126636160

Tags: PyTorch, feature maps