Pytorch数据读取之Dataset和DataLoader知识总结

作者:群星闪耀 时间:2023-11-02 22:57:37 

一、前言

确保安装

  • scikit-image

  • numpy

二、Dataset

一个例子:


# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np

# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]

#创建子类
class subDataset(Dataset.Dataset):
   #初始化,定义数据内容和标签
   def __init__(self, Data, Label):
       self.Data = Data
       self.Label = Label
   #返回数据集大小
   def __len__(self):
       return len(self.Data)
   #得到数据内容和标签
   def __getitem__(self, index):
       data = torch.Tensor(self.Data[index])
       label = torch.IntTensor(self.Label[index])
       return data, label

# 主函数
if __name__ == '__main__':
   dataset = subDataset(Data, Label)
   print(dataset)
   print('dataset大小为:', dataset.__len__())
   print(dataset.__getitem__(0))
   print(dataset[0])

 输出的结果

Pytorch数据读取之Dataset和DataLoader知识总结

我们有了对Dataset的一个整体的把握,再来分析里面的细节:


#创建子类
class subDataset(Dataset.Dataset):

创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!

lengetitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。

三、DatasetLoader


DataLoader(dataset,
          batch_size=1,
          shuffle=False,
          sampler=None,
          batch_sampler=None,
          num_works=0,
          clollate_fn=None,
          pin_memory=False,
          drop_last=False,
          timeout=0,
          worker_init_fn=None,
          multiprocessing_context=None)

功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;

还是举一个实例:


import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np

Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
   #初始化,定义数据内容和标签
   def __init__(self, Data, Label):
       self.Data = Data
       self.Label = Label
   #返回数据集大小
   def __len__(self):
       return len(self.Data)
   #得到数据内容和标签
   def __getitem__(self, index):
       data = torch.Tensor(self.Data[index])
       label = torch.IntTensor(self.Label[index])
       return data, label

if __name__ == '__main__':
   dataset = subDataset(Data, Label)
   print(dataset)
   print('dataset大小为:', dataset.__len__())
   print(dataset.__getitem__(0))
   print(dataset[0])

#创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程

dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
   for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
       print('i:', i)
       data, label = item #数据是一个元组
       print('data:', data)
       print('label:', label)

四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)

这部分可以直接去看博客:Dataset和DataLoader

总结下来时有两种方法解决

1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。

2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。

这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。

来源:https://blog.csdn.net/weixin_40244676/article/details/117043973

标签:Pytorch,Dataset,DataLoader
0
投稿

猜你喜欢

  • python编程开发时间序列calendar模块示例详解

    2023-04-25 14:59:19
  • MYSQL 批量替换之replace语法的使用详解

    2024-01-21 19:52:35
  • python opencv通过4坐标剪裁图片

    2022-06-03 20:14:03
  • Python 一句话生成字母表的方法

    2022-03-15 06:49:47
  • vue.js实现日历插件使用方法详解

    2024-05-13 09:38:43
  • 支持png透明图片的php生成缩略图类分享

    2023-11-18 07:26:13
  • OpenCV实现背景分离(证件照背景替换)

    2023-04-18 19:41:22
  • 深入了解vue2与vue3的生命周期对比

    2024-05-11 09:14:32
  • Python 聊聊socket中的listen()参数(数字)到底代表什么

    2022-10-17 00:49:25
  • 用CSS实现柱状图(Bar Graph)的方法(四)—table实现复杂柱状图

    2008-05-28 12:55:00
  • Python读取yaml文件的详细教程

    2021-03-16 20:43:27
  • 玩转CSS3色彩[译]

    2010-01-13 13:02:00
  • SqlServer 2005 T-SQL Query 学习笔记(1)

    2024-01-25 17:01:56
  • asp + oracle 分页方法

    2010-05-11 20:09:00
  • Go语言异常处理案例解析

    2024-02-04 07:26:02
  • python 同时运行多个程序的实例

    2021-03-25 07:36:20
  • Python列表的定义及使用

    2023-08-02 03:38:32
  • Flask蓝图学习教程

    2023-03-02 04:19:15
  • ElementUI日期选择器时间选择范围限制的实现

    2024-04-09 11:00:28
  • python使用Flask框架获取用户IP地址的方法

    2023-08-09 03:15:23
  • asp之家 网络编程 m.aspxhome.com