PyTorch实现重写/改写Dataset并载入Dataloader

作者:全员鳄鱼 时间:2023-10-31 17:19:35 

前言

众所周知,Dataset和Dataloder是pytorch中进行数据载入的部件。必须将数据载入后,再进行深度学习模型的训练。在pytorch的一些案例教学中,常使用torchvision.datasets自带的MNIST、CIFAR-10数据集,一般流程为:


# 下载并存放数据集
train_dataset = torchvision.datasets.CIFAR10(root="数据集存放位置",download=True)
# load数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset)

但是,在我们自己的模型训练中,需要使用非官方自制的数据集。这时应该怎么办呢?

我们可以通过改写torch.utils.data.Dataset中的__getitem____len__来载入我们自己的数据集。
__getitem__获取数据集中的数据,__len__获取整个数据集的长度(即个数)。

改写

采用pytorch官网案例中提供的一个脸部landmark数据集。数据集中含有存放landmark的csv文件,但是我们在这篇文章中不使用(其实也可以随便下载一些图片作数据集来实验)。


import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

plt.ion()  # interactive mode

torch.utils.data.Dataset是一个抽象类,我们自己的数据集需要继承Dataset,然后改写上述两个函数:


class ImageLoader(Dataset):
 def __init__(self, file_path, transform=None):
   super(ImageLoader,self).__init__()
   self.file_path = file_path
   self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
   self.image_names = os.listdir(self.file_path) # 文件名的列表

def __getitem__(self,idx):
   image = self.image_names[idx]
   image = io.imread(os.path.join(self.file_path,image))
#    if self.transform:
#    image= self.transform(image)
   return image

def __len__(self):
   return len(self.image_names)

# 设置自己存放的数据集位置,并plot展示    
imageloader = ImageLoader(file_path="D:\\Projects\\datasets\\faces\\")
# imageloader.__len__()       # 输出数据集长度(个数),应为71
# print(imageloader.__getitem__(0)) # 以数据形式展示
plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
plt.show()

得到的图片输出:

PyTorch实现重写/改写Dataset并载入Dataloader

得到的数据输出,:


array([[[ 66, 59, 53],
   [ 66, 59, 53],
   [ 66, 59, 53],
   ...,
   [ 59, 54, 48],
   [ 59, 54, 48],
   [ 59, 54, 48]],
   ...,
   [153, 141, 129],
   [158, 146, 134],
   [158, 146, 134]]], dtype=uint8)

上面看到dytpe=uint8,实际进行训练的时候,常常需要更改成float的数据类型。可以使用:


# 直接改成pytorch中的tensor下的float格式
# 也可以用numpy的改成普通的float格式
to_float= torch.from_numpy(imageloader.__getitem__(0)).float()

改写完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)载入到Dataloader中,就可以使用了。
下面的代码可以试着运行一下,产生的是一模一样的图片结果。


train_loader = torch.utils.data.DataLoader(dataset=imageloader)
train_loader.dataset[0]
plt.imshow(train_loader.dataset[0])
plt.show()

来源:https://blog.csdn.net/qq_38372240/article/details/107322677

标签:PyTorch,重写,改写,Dataset
0
投稿

猜你喜欢

  • Linux系统(CentOS)下python2.7.10安装

    2021-04-02 19:27:50
  • PHP5.6读写excel表格文件操作示例

    2023-11-21 15:03:21
  • python去除空格,tab制表符和\\n换行符的小技巧分享

    2022-05-12 14:20:39
  • PHP attributes()函数讲解

    2023-06-04 09:33:02
  • Python-OpenCV深度学习入门示例详解

    2022-07-24 02:44:24
  • Python CSS选择器爬取京东网商品信息过程解析

    2022-01-17 21:18:17
  • accept-charset与Header P3P

    2009-04-01 18:43:00
  • Python 中Django验证码功能的实现代码

    2022-05-01 22:55:39
  • SQLServer 连接失败错误故障的分析与排除

    2024-01-24 09:09:42
  • MySQL一个索引最多有多少个列?真实的测试例子

    2024-01-20 18:32:39
  • PHP获取url的函数代码

    2023-10-15 12:45:00
  • python环形单链表的约瑟夫问题详解

    2023-03-02 04:13:10
  • Python图像处理之膨胀与腐蚀的操作

    2022-10-07 19:47:06
  • python实现弹窗祝福效果

    2021-09-08 04:22:15
  • Python实现命令行通讯录实例教程

    2023-10-18 01:51:28
  • asp简单的仿图片验证码

    2008-03-12 11:54:00
  • pyinstaller还原python代码过程图解

    2022-04-09 10:06:59
  • Ubuntu权限不足无法创建文件夹解决方案

    2021-04-06 01:31:27
  • 解决pycharm 工具栏Tool中找不到Run manager.py Task的问题

    2023-02-26 09:04:31
  • JavaScript来实现打开链接页面的简单实例

    2024-04-30 09:51:32
  • asp之家 网络编程 m.aspxhome.com