pytorch 数据加载性能对比分析

作者:ShellCollector 时间:2022-04-17 04:22:22 

传统方式需要10s,dat方式需要0.6s


import os
import time
import torch
import random
from common.coco_dataset import COCODataset
def gen_data(batch_size,data_path,target_path):
os.makedirs(target_path,exist_ok=True)
dataloader = torch.utils.data.DataLoader(COCODataset(data_path,
              (352, 352),
              is_training=False, is_scene=True),
           batch_size=batch_size,
           shuffle=False, num_workers=0, pin_memory=False,
           drop_last=True) # DataLoader
start = time.time()
for step, samples in enumerate(dataloader):
 images, labels, image_paths = samples["image"], samples["label"], samples["img_path"]
 print("time", images.size(0), time.time() - start)
 start = time.time()
 # torch.save(samples,target_path+ '/' + str(step) + '.dat')
 print(step)
def cat_100(target_path,batch_size=100):
paths = os.listdir(target_path)
li = [i for i in range(len(paths))]
random.shuffle(li)
images = []
labels = []
image_paths = []
start = time.time()
for i in range(len(paths)):
 samples = torch.load(target_path + str(li[i]) + ".dat")
 image, label, image_path = samples["image"], samples["label"], samples["img_path"]
 images.append(image.cuda())
 labels.append(label.cuda())
 image_paths.append(image_path)
 if i % batch_size == batch_size - 1:
  images = torch.cat((images), 0)
  print("time", images.size(0), time.time() - start)
  images = []
  labels = []
  image_paths = []
  start = time.time()
 i += 1
if __name__ == '__main__':
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
batch_size=320
# target_path='d:/test_1000/'
target_path='d:\img_2/'
data_path = r'D:\dataset\origin_all_datas\_2train'
gen_data(batch_size,data_path,target_path)
# get_data(target_path,batch_size)
# cat_100(target_path,batch_size)

这个读取数据也比较快:320 batch_size 450ms


def cat_100(target_path,batch_size=100):
paths = os.listdir(target_path)
li = [i for i in range(len(paths))]
random.shuffle(li)
images = []
labels = []
image_paths = []
start = time.time()
for i in range(len(paths)):
 samples = torch.load(target_path + str(li[i]) + ".dat")
 image, label, image_path = samples["image"], samples["label"], samples["img_path"]
 images.append(image)#.cuda())
 labels.append(label)#.cuda())
 image_paths.append(image_path)
 if i % batch_size < batch_size - 1:
  i += 1
  continue
 i += 1
 images = torch.cat(([image.cuda() for image in images]), 0)
 print("time", images.size(0), time.time() - start)
 images = []
 labels = []
 image_paths = []
 start = time.time()

补充:pytorch数据加载和处理问题解决方案

最近跟着pytorch中文文档学习遇到一些小问题,已经解决,在此对这些错误进行记录:

在读取数据集时报错:

AttributeError: 'Series' object has no attribute 'as_matrix'

在显示图片是时报错:

ValueError: Masked arrays must be 1-D

显示单张图片时figure一闪而过

在显示多张散点图的时候报错:

TypeError: show_landmarks() got an unexpected keyword argument 'image'

解决方案

主要问题在这一行: 最终目的是将Series转为Matrix,即调用np.mat即可完成。

修改前


landmarks =landmarks_frame.iloc[n, 1:].as_matrix()

修改后


landmarks =np.mat(landmarks_frame.iloc[n, 1:])

打散点的x和y坐标应该均为向量或列表,故将landmarks后使用tolist()方法即可

修改前


plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')

修改后


plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')

前面使用plt.ion()打开交互模式,则后面在plt.show()之前一定要加上plt.ioff()。这里直接加到函数里面,避免每次plt.show()之前都用plt.ioff()

修改前


def show_landmarks(imgs,landmarks):
'''显示带有地标的图片'''
plt.imshow(imgs)
plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点
plt.pause(1)#绘图窗口延时

修改后


def show_landmarks(imgs,landmarks):
'''显示带有地标的图片'''
plt.imshow(imgs)
plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点
plt.pause(1)#绘图窗口延时
plt.ioff()

网上说对于字典类型的sample可通过 **sample的方式获取每个键下的值,但是会报错,于是把输入写的详细一点,就成功了。

修改前


show_landmarks(**sample)

修改后


show_landmarks(sample['image'],sample['landmarks'])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。

来源:https://blog.csdn.net/jacke121/article/details/85236561

标签:pytorch,数据,加载
0
投稿

猜你喜欢

  • Python 找出出现次数超过数组长度一半的元素实例

    2023-06-07 05:50:33
  • python 模拟网站登录——滑块验证码的识别

    2023-04-17 16:16:29
  • python3通过udp实现组播数据的发送和接收操作

    2023-01-14 02:27:42
  • 实现php删除链表中重复的结点

    2023-09-05 09:36:15
  • Django-xadmin+rule对象级权限的实现方式

    2023-02-20 17:08:08
  • PyChon中关于Jekins的详细安装(推荐)

    2021-03-17 08:07:31
  • python 性能优化方法小结

    2022-08-04 21:13:43
  • python程序需要编译吗

    2022-08-18 12:57:13
  • python 数据类(dataclass)的具体使用

    2022-11-08 09:36:27
  • ​​​​​​​如何利用python破解zip加密文件

    2022-11-27 17:51:30
  • Python中常用的8种字符串操作方法

    2023-05-28 09:44:38
  • 使IE浏览器支持PNG格式图片的透明效果

    2008-02-02 16:20:00
  • PyQt5+Pycharm安装和配置图文教程详解

    2022-12-20 08:50:26
  • 在阿里云服务器上配置CentOS+Nginx+Python+Flask环境

    2023-07-26 09:47:46
  • python小程序之4名牌手洗牌发牌问题解析

    2023-08-28 04:06:20
  • 如何表示python中的相对路径

    2022-09-06 14:13:41
  • Flask使用SQLAlchemy实现持久化数据

    2023-02-23 07:47:19
  • Django文件上传与下载(FileFlid)

    2023-07-10 11:33:33
  • 50种方法巧妙优化SQL Server数据库

    2008-12-24 15:49:00
  • 用Python 爬取猫眼电影数据分析《无名之辈》

    2023-07-03 17:23:26
  • asp之家 网络编程 m.aspxhome.com