PyTorch基本数据类型(一)

作者:Liam Coder 时间:2023-06-15 20:56:39 

PyTorch基础入门一:PyTorch基本数据类型

1)Tensor(张量)

Pytorch里面处理的最基本的操作对象就是Tensor(张量),它表示的其实就是一个多维矩阵,并有矩阵相关的运算操作。在使用上和numpy是对应的,它和numpy唯一的不同就是,pytorch可以在GPU上运行,而numpy不可以。所以,我们也可以使用Tensor来代替numpy的使用。当然,二者也可以相互转换。

Tensor的基本数据类型有五种:

  • 32位浮点型:torch.FloatTensor。pyorch.Tensor()默认的就是这种类型。

  • 64位整型:torch.LongTensor。

  • 32位整型:torch.IntTensor。

  • 16位整型:torch.ShortTensor。

  • 64位浮点型:torch.DoubleTensor。

那么如何定义Tensor张量呢?其实定义的方式和numpy一样,直接传入相应的矩阵即可即可。下面就定义了一个三行两列的矩阵:


import torch
# 导包

a = torch.Tensor([[1, 2], [3, 4], [5, 6]])
print(a)

不过在项目之中,更多的做法是以特殊值或者随机值初始化一个矩阵,就像下面这样:


import torch

# 定义一个3行2列的全为0的矩阵
b = torch.zeros((3, 2))

# 定义一个3行2列的随机值矩阵
c = torch.randn((3, 2))

# 定义一个3行2列全为1的矩阵
d = torch.ones((3, 2))

print(b)
print(c)
print(d)

Tensor和numpy.ndarray之间还可以相互转换,其方式如下:

  • Numpy转化为Tensor:torch.from_numpy(numpy矩阵)

  • Tensor转化为numpy:Tensor矩阵.numpy()

范例如下:


import torch
import numpy as np

# 定义一个3行2列的全为0的矩阵
b = torch.randn((3, 2))

# tensor转化为numpy
numpy_b = b.numpy()
print(numpy_b)

# numpy转化为tensor
numpy_e = np.array([[1, 2], [3, 4], [5, 6]])
torch_e = torch.from_numpy(numpy_e)

print(numpy_e)
print(torch_e)

之前说过,numpy与Tensor最大的区别就是在对GPU的支持上。Tensor只需要调用cuda()函数就可以将其转化为能在GPU上运行的类型。

我们可以通过torch.cuda.is_available()函数来判断当前的环境是否支持GPU,如果支持,则返回True。所以,为保险起见,在项目代码中一般采取“先判断,后使用”的策略来保证代码的正常运行,其基本结构如下:


import torch

# 定义一个3行2列的全为0的矩阵
tmp = torch.randn((3, 2))

# 如果支持GPU,则定义为GPU类型
if torch.cuda.is_available():
 inputs = tmp.cuda()
# 否则,定义为一般的Tensor类型
else:
 inputs = tmp

2)Variable(变量)

Pytorch里面的Variable类型数据功能更加强大,相当于是在Tensor外层套了一个壳子,这个壳子赋予了前向传播,反向传播,自动求导等功能,在计算图的构建中起的很重要的作用。Variable的结构图如下:

PyTorch基本数据类型(一)

其中最重要的两个属性是:data和grad。Data表示该变量保存的实际数据,通过该属性可以访问到它所保存的原始张量类型,而关于该 variable(变量)的梯度会被累计到.grad 上去。

在使用Variable的时候需要从torch.autograd中导入。下面通过一个例子来看一下它自动求导的过程:


import torch
from torch.autograd import Variable

# 定义三个Variable变量
x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
w = Variable(torch.Tensor([2, 3, 4]), requires_grad=True)
b = Variable(torch.Tensor([3, 4, 5]), requires_grad=True)

# 构建计算图,公式为:y = w * x^2 + b
y = w * x * x + b

# 自动求导,计算梯度
y.backward(torch.Tensor([1, 1, 1]))

print(x.grad)
print(w.grad)
print(b.grad)

上述代码的计算图为y = w * x^2 + b。对x, w, b分别求偏导为:x.grad = 2wx,w.grad=x^2,b.grad=1。代值检验可得计算结果是正确的。

来源:https://blog.csdn.net/out_of_memory_error/article/details/81258809

标签:PyTorch,数据类型
0
投稿

猜你喜欢

  • jmeter实现接口关联的两种方式(正则表达式提取器和json提取器)

    2022-06-14 19:31:27
  • ASP 隐藏下载地址及防盗链代码

    2011-02-26 11:17:00
  • python获取全国最新省市区数据并存入表实例代码

    2021-10-19 14:16:23
  • JS实现求5的阶乘示例

    2024-04-30 10:09:43
  • Window下安装JDK1.8+Tomcat9.0.27+Mysql5.7.28的教程图解

    2024-01-26 22:24:12
  • 利用Python查看目录中的文件示例详解

    2023-02-06 14:13:28
  • 详解Tensorflow数据读取有三种方式(next_batch)

    2023-08-10 07:30:42
  • python判断数字是否是超级素数幂

    2023-12-24 06:16:31
  • python中随机函数random用法实例

    2023-02-09 22:13:10
  • 详解用Python处理HTML转义字符的5种方式

    2021-01-27 20:53:17
  • pandas多级分组实现排序的方法

    2022-05-06 14:16:11
  • 一文详解Python中的super 函数

    2022-02-26 03:18:35
  • Python实现多个视频合成一个视频的功能

    2021-10-31 12:57:44
  • Python搭建代理IP池实现接口设置与整体调度

    2023-05-25 11:52:03
  • 2010怎么就宅了——我们是设计星球的阿凡达

    2010-03-09 13:26:00
  • Pycharm 如何一键加引号的方法步骤

    2022-09-11 19:19:34
  • 使用PyCharm安装pytest及requests的问题

    2023-01-17 13:10:46
  • Python的Django框架中的表单处理示例

    2023-02-06 18:57:31
  • 如何利用Python给自己的头像加一个小国旗(小月饼)

    2023-09-15 20:47:31
  • Mysql8.0.22解压版安装教程(小白专用)

    2024-01-15 15:37:29
  • asp之家 网络编程 m.aspxhome.com