pytorch 带batch的tensor类型图像显示操作

作者:Xavier Jiezou 时间:2023-06-02 08:47:26 

项目场景

pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。

那么如何显示dataloader里面带batch的tensor类型的图像呢?

显示图像

绘图最常用的库就是matplotlib:


pip install matplotlib

显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:

pytorch 带batch的tensor类型图像显示操作

数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)

用法示例如下:


>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])

代码示例


#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
   root='./dataset/',
   train=True,
   transform=transforms.Compose([
       transforms.ToTensor(),
       transforms.Normalize((0.1307,), (0.3081,))
   ]),
   download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
   dataset=train_file,
   batch_size=9,
   shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size())  # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
   plt.subplot(3, 3, i+1)
   plt.title(labels[i].item())
   plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
   plt.axis('off')
plt.show()

这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。

所以,如果你想查看训练集的原始图像,还得反标准化。

标准化:image = (image-mean)/std

反标准化:image = image*std+mean

我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:

pytorch 带batch的tensor类型图像显示操作

最终效果

pytorch 带batch的tensor类型图像显示操作

补充:PIL,plt显示tensor类型的图像

该方法针对显示Dataloader读取的图像

PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。


# 方法1:Image.show()
# transforms.ToPILImage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
img = transforms.ToPILImage(image[0])
img.show()

# 方法2:plt.imshow(ndarray)
img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
img = img.numpy() # FloatTensor转为ndarray
img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
# 显示图片
plt.imshow(img)
plt.show()
cnt += 1

来源:https://blog.csdn.net/qq_42951560/article/details/109962828

标签:pytorch,batch,tensor,图像
0
投稿

猜你喜欢

  • Python计算斗牛游戏概率算法实例分析

    2021-08-08 09:52:21
  • python中的字符串切割 maxsplit

    2022-04-16 14:35:35
  • vue单页应用中如何使用jquery的方法示例

    2024-05-09 10:40:14
  • php引用和拷贝的区别知识点总结

    2023-11-15 03:39:48
  • [整理版]防止Access数据库被下载的9种方法

    2007-08-10 09:31:00
  • Python运算符的应用超全面详细教程

    2023-08-20 18:24:56
  • Python编程pytorch深度卷积神经网络AlexNet详解

    2022-02-18 10:28:40
  • 简单介绍Python中的readline()方法的使用

    2023-11-02 13:34:30
  • JS判断浏览器类型与版本的实现代码

    2024-05-13 10:36:32
  • Bootstrap3制作自己的导航栏

    2023-08-23 02:13:08
  • Go语言底层原理互斥锁的实现原理

    2024-04-25 15:00:24
  • Mysql中Join的使用实例详解

    2024-01-26 05:04:36
  • SQL Server 2012 安装与启动图文教程

    2024-01-27 08:33:35
  • 详解JavaScript 中的批处理和缓存

    2024-04-28 09:48:03
  • 在Python的Django框架中编写编译函数

    2022-01-04 16:53:34
  • Python连接phoenix的方法示例

    2023-05-24 06:25:19
  • Windows存储 SQL行溢出 差异备份及疑问

    2008-12-24 15:22:00
  • Python设置Word全局样式和文本样式的示例代码

    2022-06-29 05:06:07
  • 总结Python连接CS2000的详细步骤

    2023-04-21 20:26:33
  • Vue中利用better-scroll组件实现横向滚动功能

    2024-05-09 15:28:29
  • asp之家 网络编程 m.aspxhome.com