PyTorch中torch.tensor()和torch.to_tensor()的区别

作者:Enzo?想砸电脑 时间:2022-11-18 11:59:42 

前言

在跑模型的时候,遇到如下报错

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).

网上查了一下,发现将 torch.tensor() 改写成 torch.as_tensor() 就可以避免报错了。

# 如下写法报错
feature = torch.tensor(image, dtype=torch.float32)

# 改为
feature = torch.as_tensor(image, dtype=torch.float32)

然后就又仔细研究了下 torch.as_tensor()torch.tensor() 的区别,在此记录。

1、torch.as_tensor()

new_data = torch.as_tensor(data, dtype=None,device=None)->Tensor

作用:生成一个新的 tensor, 这个新生成的tensor 会根据原数据的实际情况,来决定是进行浅拷贝,还是深拷贝。当然,会优先浅拷贝,浅拷贝会共享内存,并共享 autograd 历史记录。

情况一:数据类型相同 且 device相同,会进行浅拷贝,共享内存

import numpy
import torch

a = numpy.array([1, 2, 3])
t = torch.as_tensor(a)
t[0] = -1

print(a)   # [-1  2  3]
print(a.dtype)   # int64
print(t)   # tensor([-1,  2,  3])
print(t.dtype)   # torch.int64
import numpy
import torch

a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.as_tensor(a)
t[0] = -1

print(a)   # tensor([-1,  2,  3], device='cuda:0')
print(t)   # tensor([-1,  2,  3], device='cuda:0')

情况二: 数据类型相同,但是device不同,深拷贝,不再共享内存

import numpy
import torch

import numpy
a = numpy.array([1, 2, 3])
t = torch.as_tensor(a, device=torch.device('cuda'))
t[0] = -1

print(a)   # [1 2 3]
print(a.dtype)   # int64
print(t)   # tensor([-1,  2,  3], device='cuda:0')
print(t.dtype)   # torch.int64

情况三:device相同,但数据类型不同,深拷贝,不再共享内存

import numpy
import torch

a = numpy.array([1, 2, 3])
t = torch.as_tensor(a, dtype=torch.float32)
t[0] = -1

print(a)   # [1 2 3]
print(a.dtype)   # int64
print(t)   # tensor([-1.,  2.,  3.])
print(t.dtype)   # torch.float32

2、torch.tensor()

torch.tensor() 是深拷贝方式。

torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)

深拷贝:会拷贝 数据类型 和 device,不会记录 autograd 历史 (also known as a “leaf tensor” 叶子tensor)

重点是:

  • 如果原数据的数据类型是:list, tuple, NumPy ndarray, scalar, and other types,不会 waring

  • 如果原数据的数据类型是:tensor,使用 torch.tensor(data) 就会报waring

# 原数据类型是:tensor 会发出警告
import numpy
import torch

a = torch.tensor([1, 2, 3], device=torch.device('cuda'))
t = torch.tensor(a)
t[0] = -1

print(a)
print(t)

# 输出:
# tensor([1, 2, 3], device='cuda:0')
# tensor([-1,  2,  3], device='cuda:0')
# /opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
# 原数据类型是:list, tuple, NumPy ndarray, scalar, and other types, 没警告
import torch
import numpy

a =  numpy.array([1, 2, 3])
t = torch.tensor(a)

b = [1,2,3]
t= torch.tensor(b)

c = (1,2,3)
t= torch.tensor(c)

结论就是:以后尽量用 torch.as_tensor()

来源:https://blog.csdn.net/weixin_37804469/article/details/128767214

标签:torch.tensor,torch.Tensor,区别
0
投稿

猜你喜欢

  • overflow的另类用法

    2008-07-02 12:29:00
  • 胜过语言的图形符号

    2009-05-06 12:43:00
  • python虚拟环境的安装配置图文教程

    2023-09-23 09:03:04
  • python实现机器学习之多元线性回归

    2022-09-04 01:42:49
  • 最炫Python烟花代码全解析

    2022-02-16 13:07:53
  • python中input()的用法及扩展

    2021-07-05 08:45:57
  • Python 动态变量名定义与调用方法

    2023-07-29 22:36:05
  • 使用access数据库时可能用到的数据转换

    2008-09-10 12:49:00
  • asp如何显示最后十名来访者信息?

    2010-06-09 18:45:00
  • python处理json字符串(使用json.loads而不是eval())

    2023-06-13 11:50:39
  • 用python实现监控视频人数统计

    2022-04-03 16:01:31
  • 一文带你了解Golang中的缓冲区Buffer

    2024-04-23 09:47:18
  • 一文带你了解Python中的字符串是什么

    2021-10-16 06:05:27
  • python3中利用filter函数输出小于某个数的所有回文数实例

    2022-05-01 13:08:07
  • python计算n的阶乘的方法代码

    2023-08-20 07:33:00
  • 再说淘宝的评价和信用机制

    2008-07-10 12:43:00
  • opencv 实现特定颜色线条提取与定位操作

    2023-09-07 01:24:26
  • 详解Mysql双机热备和负载均衡的实现步骤

    2024-01-15 09:00:50
  • FF下,用 col 隐藏表格列的方法详解!

    2008-04-02 11:35:00
  • 关于JS中变量的显式申明和隐式申明

    2008-09-12 13:04:00
  • asp之家 网络编程 m.aspxhome.com