PyTorch小功能之TensorDataset解读

作者:菜鸟向前冲fighting 时间:2023-02-26 06:06:27 

PyTorch之TensorDataset

TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。

该类通过每一个 tensor 的第一个维度进行索引。

因此,该类中的 tensor 第一维度必须相等。

from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)
# 切片输出
print(train_ids[0:2])
print('=' * 80)
# 循环取数据
for x_train, y_label in train_ids:
   print(x_train, y_label)
# DataLoader进行数据封装
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):  # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
   x_data, label = data
   print(' batch:{0} x_data:{1}  label: {2}'.format(i, x_data, label))

运行结果:

(tensor([[1, 2, 3],
        [4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
 batch:1 x_data:tensor([[1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6]])  label: tensor([44, 44, 55, 55])
 batch:2 x_data:tensor([[4, 5, 6],
        [7, 8, 9],
        [7, 8, 9],
        [7, 8, 9]])  label: tensor([55, 66, 66, 66])
 batch:3 x_data:tensor([[1, 2, 3],
        [1, 2, 3],
        [7, 8, 9],
        [4, 5, 6]])  label: tensor([44, 44, 66, 55])

注意:TensorDataset 中的参数必须是 tensor

Pytorch中TensorDataset的快速使用

Pytorch中,TensorDataset()可以快速构建训练所用的数据,不用使用自建的Mydataset(),如果没有熟悉适用的dataset可以使用TensorDataset()作为暂时替代。

只需要把data和label作为参数输入,就可以快速构建,之后便可以用Dataloader处理。

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) 

来源:https://blog.csdn.net/qq_40211493/article/details/107529148

标签:PyTorch,TensorDataset
0
投稿

猜你喜欢

  • MySQL中SQL的单字节注入与宽字节注入

    2009-03-25 14:49:00
  • 使用CSS3和RGBa创建超酷的按钮

    2009-06-02 12:41:00
  • Sql中将datetime转换成字符串的方法(CONVERT)

    2024-01-22 10:25:59
  • Vue+Element自定义纵向表格表头教程

    2023-07-02 17:10:38
  • 详解Node.js中的事件机制

    2024-05-03 15:58:52
  • Python实现简易的图书管理系统

    2021-09-12 06:06:21
  • 纯JS实现动态时间显示代码

    2024-05-02 17:31:34
  • JS关于刷新页面的相关总结

    2024-04-22 12:52:36
  • python WindowsError的错误代码详解

    2021-09-03 18:58:45
  • Golang通过包长协议处理TCP粘包的问题解决

    2024-04-30 10:00:11
  • asp随机获取access数据库中的一条记录

    2007-08-15 13:11:00
  • Python爬虫:url中带字典列表参数的编码转换方法

    2021-11-02 17:50:45
  • MySQL数据库自动补全命令的三种方法

    2024-01-26 16:58:35
  • pytorch loss反向传播出错的解决方案

    2023-04-06 07:20:55
  • mysql中复制表结构的方法小结

    2024-01-19 22:54:26
  • Python使用ctypes调用C/C++的方法

    2023-09-01 21:27:47
  • python中使用ctypes调用so传参设置遇到的问题及解决方法

    2021-06-02 00:38:39
  • Python xlrd/xlwt 创建excel文件及常用操作

    2021-08-17 04:33:22
  • 利用scrapy将爬到的数据保存到mysql(防止重复)

    2024-01-23 15:35:28
  • js小方框中浏览大图类似google earth效果

    2007-10-28 19:30:00
  • asp之家 网络编程 m.aspxhome.com