PyTorch数据读取的实现示例

作者:YXHPY 时间:2022-01-31 04:15:48 

前言

PyTorch作为一款深度学习框架,已经帮助我们实现了很多很多的功能了,包括数据的读取和转换了,那么这一章节就介绍一下PyTorch内置的数据读取模块吧

模块介绍

  • pandas 用于方便操作含有字符串的表文件,如csv

  • zipfile python内置的文件解压包

  • cv2 用于图片处理的模块,读入的图片模块为BGR,N H W C

  • torchvision.transforms 用于图片的操作库,比如随机裁剪、缩放、模糊等等,可用于数据的增广,但也不仅限于内置的图片操作,也可以自行进行图片数据的操作,这章也会讲解

  • torch.utils.data.Dataset torch内置的对象类型

  • torch.utils.data.DataLoader 和Dataset配合使用可以实现数据的加速读取和随机读取等等功能


import zipfile # 解压
import pandas as pd # 操作数据
import os # 操作文件或文件夹
import cv2 # 图像操作库
import matplotlib.pyplot as plt # 图像展示库
from torch.utils.data import Dataset # PyTorch内置对象
from torchvision import transforms # 图像增广转换库 PyTorch内置
import torch

初步读取数据

数据下载到此处
我们先初步编写一个脚本来实现图片的展示


# 解压文件到指定目录
def unzip_file(root_path, filename):
 full_path = os.path.join(root_path, filename)
 file = zipfile.ZipFile(full_path)
 file.extractall(root_path)
unzip_file(root_path, zip_filename)

# 读入csv文件
face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))

# pandas读出的数据如想要操作索引 使用iloc
image_name = face_landmarks.iloc[:,0]
landmarks = face_landmarks.iloc[:,1:]

# 展示
def show_face(extract_path, image_file, face_landmark):
 plt.imshow(plt.imread(os.path.join(extract_path, image_file)), cmap='gray')
 point_x = face_landmark.to_numpy()[0::2]
 point_y = face_landmark.to_numpy()[1::2]
 plt.scatter(point_x, point_y, c='r', s=6)

show_face(extract_path, image_name.iloc[1], landmarks.iloc[1])

PyTorch数据读取的实现示例

使用内置库来实现

实现MyDataset

使用内置库是我们的代码更加的规范,并且可读性也大大增加
继承Dataset,需要我们实现的有两个地方:

  • 实现__len__返回数据的长度,实例化调用len()时返回

  • __getitem__给定数据的索引返回对应索引的数据如:a[0]

  • transform 数据的额外操作时调用


class FaceDataset(Dataset):
 def __init__(self, extract_path, csv_filename, transform=None):
   super(FaceDataset, self).__init__()
   self.extract_path = extract_path
   self.csv_filename = csv_filename
   self.transform = transform
   self.face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))
 def __len__(self):
   return len(self.face_landmarks)
 def __getitem__(self, idx):
   image_name = self.face_landmarks.iloc[idx,0]
   landmarks = self.face_landmarks.iloc[idx,1:].astype('float32')
   point_x = landmarks.to_numpy()[0::2]
   point_y = landmarks.to_numpy()[1::2]
   image = plt.imread(os.path.join(self.extract_path, image_name))
   sample = {'image':image, 'point_x':point_x, 'point_y':point_y}
   if self.transform is not None:
     sample = self.transform(sample)
   return sample

测试功能是否正常


face_dataset = FaceDataset(extract_path, csv_filename)
sample = face_dataset[0]
plt.imshow(sample['image'], cmap='gray')
plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
plt.title('face')

PyTorch数据读取的实现示例

实现自己的数据处理模块

内置的在torchvision.transforms模块下,由于我们的数据结构不能满足内置模块的要求,我们就必须自己实现
图片的缩放,由于缩放后人脸的标注位置也应该发生对应的变化,所以要自己实现对应的变化


class Rescale(object):
 def __init__(self, out_size):
   assert isinstance(out_size,tuple) or isinstance(out_size,int), 'out size isinstance int or tuple'
   self.out_size = out_size
 def __call__(self, sample):
   image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
   new_h, new_w = self.out_size if isinstance(self.out_size,tuple) else (self.out_size, self.out_size)
   new_image = cv2.resize(image,(new_w, new_h))
   h, w = image.shape[0:2]
   new_y = new_h / h * point_y
   new_x = new_w / w * point_x
   return {'image':new_image, 'point_x':new_x, 'point_y':new_y}

将数据转换为torch认识的数据格式因此,就必须转换为tensor
注意: cv2matplotlib读出的图片默认的shape为N H W C,而torch默认接受的是N C H W因此使用tanspose转换维度,torch转换多维度使用permute


class ToTensor(object):
 def __call__(self, sample):
   image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']
   new_image = image.transpose((2,0,1))
   return {'image':torch.from_numpy(new_image), 'point_x':torch.from_numpy(point_x), 'point_y':torch.from_numpy(point_y)}

测试


transform = transforms.Compose([Rescale((1024, 512)), ToTensor()])
face_dataset = FaceDataset(extract_path, csv_filename, transform=transform)
sample = face_dataset[0]
plt.imshow(sample['image'].permute((1,2,0)), cmap='gray')
plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)
plt.title('face')

PyTorch数据读取的实现示例

使用Torch内置的loader加速读取数据


data_loader = DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=0)
for i in data_loader:
 print(i['image'].shape)
 break

torch.Size([4, 3, 1024, 512])

注意: windows环境尽量不使用num_workers会发生报错

来源:https://blog.csdn.net/weixin_42263486/article/details/108295120

标签:PyTorch,数据读取
0
投稿

猜你喜欢

  • Python从使用线程到使用async/await的深入讲解

    2021-07-26 10:56:11
  • 如何在Python中进行异常处理

    2021-02-21 06:51:01
  • 通过表单的做为二进制文件上传request.totalbytes提取出上传的二级制数据

    2011-03-16 10:39:00
  • python 将列表中的字符串连接成一个长路径的方法

    2023-04-18 14:06:25
  • sina和265天气预报调用代码

    2007-11-19 13:32:00
  • Python常问的100个面试问题汇总(下篇)

    2023-09-23 06:30:29
  • python3.9实现pyinstaller打包python文件成exe

    2022-10-28 18:27:35
  • Python 读取串口数据,动态绘图的示例

    2021-11-15 19:36:24
  • 在ASP.NET 2.0中操作数据之十二:在GridView控件中使用TemplateField

    2023-07-07 07:02:50
  • Python爬虫基础之XPath语法与lxml库的用法详解

    2022-07-03 20:56:06
  • python制作花瓣网美女图片爬虫

    2023-05-20 01:51:55
  • Django处理Ajax发送的Get请求代码详解

    2023-06-29 08:40:40
  • Python中无限循环需要什么条件

    2023-03-28 09:05:14
  • uniqueidentifier转换成varchar数据类型的sql语句

    2011-09-30 11:17:48
  • python的Jenkins接口调用方式

    2022-02-23 11:26:23
  • 使用python脚本实现查询火车票工具

    2021-03-10 05:48:54
  • Pytorch 之修改Tensor部分值方式

    2023-04-11 06:45:12
  • python+gdal+遥感图像拼接(mosaic)的实例

    2023-02-22 23:40:34
  • .NET Core2.1如何获取自定义配置文件信息详解

    2023-07-17 16:26:34
  • Python2与Python3的区别实例分析

    2021-01-07 11:47:17
  • asp之家 网络编程 m.aspxhome.com