详解PyTorch预定义数据集类datasets.ImageFolder使用方法

作者:实力 时间:2022-01-30 20:15:07 

datasets.ImageFolder是PyTorch提供的一个预定义数据集类,用于处理图像数据。它可以方便地将一组图像加载到内存中,并为每个图像分配标签。

数据集准备和目录结构

要使用datasets.ImageFolder,我们需要准备好一个包含图像数据的目录,并按照以下方式进行组织:

root/
   class1/
       img1.jpg
       img2.jpg
       ...
   class2/
       img1.jpg
       img2.jpg
       ...
   ...

其中,root代表数据集根目录,class1、class2等代表不同的分类标签,img1、img2等代表图像文件名。每个类别(也称为标签)应该有一个单独的子目录,子目录中包含这个类别的所有图像文件。同时,每个图像文件在对应的子目录下,以其文件名作为其类别标签。这种目录组织方式可以让我们轻松获取图像和对应的标签信息。

加载数据集

完成数据集准备之后,我们就可以使用datasets.ImageFolder来加载它了。下面是一个示例代码:

import torchvision.datasets as datasets
import torchvision.transforms as transforms
data_dir = "/path/to/data"
transforms = transforms.Compose([
   transforms.Resize(size=(224, 224)),
   transforms.ToTensor(),
])
dataset = datasets.ImageFolder(root=data_dir, transform=transforms)

在这个例子中,我们首先导入datasets和transforms模块,然后指定数据集的根目录data_dir。接下来,我们定义一个 transforms 对象,它将图像转换为PyTorch张量,并调整大小为(224, 224)。

最后,我们使用datasets.ImageFolder来加载图像数据集。ImageFolder类需要两个参数:root 和 transform。root是数据集根目录;transform指定对每个图像应该执行的预处理操作,例如调整大小、裁剪、翻转等。

数据集划分

对于机器学习任务,我们通常需要将数据集划分成训练集、验证集和测试集。在PyTorch中,我们可以使用torch.utils.data.random_split函数来完成数据集的划分。下面是一个示例代码:

from torch.utils.data import DataLoader, random_split
# Split the dataset into train and test sets
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# Split train dataset into train and validation sets
val_size = int(0.2 * len(train_dataset))
train_size = len(train_dataset) - val_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

在这个例子中,我们先使用random_split函数将原始数据集划分为训练集和测试集,在这里80%的数据用于训练,20%的数据用于测试。然后,我们再次使用random_split函数将训练集划分为训练集和验证集,其中80%的数据用于训练,20%的数据用于验证。

数据加载器

最后,我们可以使用数据加载器(DataLoader)来加载数据集。数据加载器负责将图像数据和标签封装成批量,并提供多线程方式加载数据以加速训练过程。下面是一个示例代码:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

在这里,我们创建了三个数据加载器train_loader、val_loader 和 test_loader,它们分别对应训练集、验证集和测试集。batch_size参数指定了每个批次的大小,shuffle参数表示是否随机化输入数据(在训练集中设置为True,在验证集和测试集中设置为False)。

来源:https://juejin.cn/post/7223988948069302329

标签:PyTorch,datasets.ImageFolder,预定义,数据集类
0
投稿

猜你喜欢

  • OverFlow:一个秘密武器

    2011-02-26 15:41:00
  • ORACLE数据库查看执行计划的方法

    2012-06-06 20:15:52
  • 如何让新页面在新窗口打开?

    2009-04-12 19:41:00
  • python回调函数中使用多线程的方法

    2022-11-08 20:01:15
  • 将HTML表单数据存储为XML格式

    2007-08-23 13:04:00
  • 利用Python如何实现K-means聚类算法

    2023-09-16 09:17:38
  • python async with和async for的使用

    2021-10-16 16:31:02
  • python 密码验证(滑块验证)

    2021-01-24 02:32:18
  • Tensorflow训练模型越来越慢的2种解决方案

    2021-06-04 20:55:53
  • 关于Python的高级数据结构与算法

    2023-12-22 00:29:40
  • python实现简单图片物体标注工具

    2021-09-07 21:31:50
  • 30分钟就入门的正则表达式基础教程

    2024-05-13 10:37:40
  • python使用JSON模块进行数据处理(编码解码)

    2024-01-01 21:52:42
  • PHP延迟静态绑定的深入讲解

    2024-06-05 15:42:51
  • 前端来看看 maxthon bugs

    2008-09-23 18:35:00
  • mysql全文搜索 sql命令的写法

    2024-01-25 04:45:38
  • 详解OpenCV中直方图,掩膜和直方图均衡化的实现

    2022-10-30 12:03:33
  • Python数据提取-lxml模块

    2022-04-03 15:15:19
  • MySQL8 批量修改字符集脚本

    2024-01-16 12:50:34
  • 利用Python的Django框架中的ORM建立查询API

    2023-11-15 10:06:03
  • asp之家 网络编程 m.aspxhome.com