pytorch深度神经网络入门准备自己的图片数据

作者:denny402 时间:2023-12-07 13:55:58 

图片数据一般有两种情况:

1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。

2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。

针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:

一、所有图片放在一个文件夹内

这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。

先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:

import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
   './mnist', train=False, download=True
)
print('test set:', len(mnist_test))
f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
   img_path="./mnist_test/"+str(i)+".jpg"
   io.imsave(img_path,img)
   f.write(img_path+' '+str(label)+'\n')
f.close()

经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:

pytorch深度神经网络入门准备自己的图片数据

前期工作就装备好了,接着就进入正题了:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
def default_loader(path):
   return Image.open(path).convert('RGB')
class MyDataset(Dataset):
   def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
       fh = open(txt, 'r')
       imgs = []
       for line in fh:
           line = line.strip('\n')
           line = line.rstrip()
           words = line.split()
           imgs.append((words[0],int(words[1])))
       self.imgs = imgs
       self.transform = transform
       self.target_transform = target_transform
       self.loader = loader
   def __getitem__(self, index):
       fn, label = self.imgs[index]
       img = self.loader(fn)
       if self.transform is not None:
           img = self.transform(img)
       return img,label
   def __len__(self):
       return len(self.imgs)
train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
   grid = utils.make_grid(imgs)
   plt.imshow(grid.numpy().transpose((1, 2, 0)))
   plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
   if(i<4):
       print(i, batch_x.size(),batch_y.size())
       show_batch(batch_x)
       plt.axis('off')
       plt.show()

自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。

二、不同类别的图片放在不同的文件夹内

同样先准备数据,这里以flowers数据集为例

提取 链接: https://pan.baidu.com/s/1dcAsOOZpUfWNYR77JGXPHA?pwd=mwg6 

花总共有五类,分别放在5个文件夹下。大致如下图:

pytorch深度神经网络入门准备自己的图片数据

我的路径是d:/flowers/.

数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder

import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt
img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
                                           transform=transforms.Compose([
                                               transforms.Scale(256),
                                               transforms.CenterCrop(224),
                                               transforms.ToTensor()])
                                           )
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
   grid = utils.make_grid(imgs,nrow=5)
   plt.imshow(grid.numpy().transpose((1, 2, 0)))
   plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
   if(i<4):
       print(i, batch_x.size(), batch_y.size())
       show_batch(batch_x)
       plt.axis('off')
       plt.show()

来源:https://www.cnblogs.com/denny402/p/7512516.html

标签:pytorch,图片数据,数据准备,深度神经网络
0
投稿

猜你喜欢

  • PHP实现获取两个以逗号分割的字符串的并集

    2023-06-01 03:24:53
  • 使用PIL(Python-Imaging)反转图像的颜色方法

    2022-12-15 19:16:48
  • python利用proxybroker构建爬虫免费IP代理池的实现

    2021-10-25 21:18:25
  • Python Decorator的设计模式演绎过程解析

    2021-10-13 14:29:37
  • 编写兼容IE和FireFox的脚本

    2009-05-19 12:01:00
  • Python虚拟环境的创建和包下载过程分析

    2023-01-02 12:46:10
  • 利用Python编写简易版德州扑克小游戏

    2021-02-03 06:00:59
  • Python3+OpenCV实现简单交通标志识别流程分析

    2021-03-12 06:37:41
  • python3.x zip用法小结

    2023-08-13 05:25:05
  • asp 实现检测字符串是否为纯字母和数字组合的函数

    2009-10-04 20:39:00
  • Pygame改编飞机大战制作兔子接月饼游戏

    2023-04-09 02:57:22
  • Bootstrap实现提示框和弹出框效果

    2023-07-02 05:25:33
  • pytest中文文档之编写断言

    2023-05-05 04:11:34
  • Python创建增量目录的代码实例

    2021-12-07 04:12:55
  • python中文编码与json中文输出问题详解

    2021-03-15 17:57:18
  • 批处理写的 oracle 数据库备份还原工具

    2024-01-25 06:32:27
  • Python编程pygame模块实现移动的小车示例代码

    2021-04-13 10:16:08
  • 简单的Python2.7编程初学经验总结

    2021-03-18 01:27:11
  • python利用datetime模块计算时间差

    2021-10-07 01:02:04
  • Python MySQL数据库连接池组件pymysqlpool详解

    2024-01-22 23:59:17
  • asp之家 网络编程 m.aspxhome.com