Pytorch中torch.flatten()和torch.nn.Flatten()实例详解
作者:有人比我慢吗 时间:2021-09-15 06:39:43
torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化。
import torch
x=torch.randn(2,4,2)
print(x)
z=torch.flatten(x)
print(z)
w=torch.flatten(x,1)
print(w)
输出为:
tensor([[[-0.9814, 0.8251],
[ 0.8197, -1.0426],
[-0.8185, -1.3367],
[-0.6293, 0.6714]],
[[-0.5973, -0.0944],
[ 0.3720, 0.0672],
[ 0.2681, 1.8025],
[-0.0606, 0.4855]]])
tensor([-0.9814, 0.8251, 0.8197, -1.0426, -0.8185, -1.3367, -0.6293, 0.6714,
-0.5973, -0.0944, 0.3720, 0.0672, 0.2681, 1.8025, -0.0606, 0.4855])
tensor([[-0.9814, 0.8251, 0.8197, -1.0426, -0.8185, -1.3367, -0.6293, 0.6714]
,
[-0.5973, -0.0944, 0.3720, 0.0672, 0.2681, 1.8025, -0.0606, 0.4855]
])
torch.flatten(x,0,1)代表在第一维和第二维之间平坦化
import torch
x=torch.randn(2,4,2)
print(x)
w=torch.flatten(x,0,1) #第一维长度2,第二维长度为4,平坦化后长度为2*4
print(w.shape)
print(w)
输出为:
tensor([[[-0.5523, -0.1132],
[-2.2659, -0.0316],
[ 0.1372, -0.8486],
[-0.3593, -0.2622]],
[[-0.9130, 1.0038],
[-0.3996, 0.4934],
[ 1.7269, 0.8215],
[ 0.1207, -0.9590]]])
torch.Size([8, 2])
tensor([[-0.5523, -0.1132],
[-2.2659, -0.0316],
[ 0.1372, -0.8486],
[-0.3593, -0.2622],
[-0.9130, 1.0038],
[-0.3996, 0.4934],
[ 1.7269, 0.8215],
[ 0.1207, -0.9590]])
对于torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。
import torch
#随机32个通道为1的5*5的图
x=torch.randn(32,1,5,5)
model=torch.nn.Sequential(
#输入通道为1,输出通道为6,3*3的卷积核,步长为1,padding=1
torch.nn.Conv2d(1,6,3,1,1),
torch.nn.Flatten()
)
output=model(x)
print(output.shape) # 6*(7-3+1)*(7-3+1)
输出为:
torch.Size([32, 150])
来源:https://blog.csdn.net/Super_user_and_woner/article/details/120782656
标签:pytorch,torch.flatten(),torch.nn.flatten()
0
投稿
猜你喜欢
python修改txt文件中的某一项方法
2021-02-08 14:26:40
Python读取图片为16进制表示简单代码
2021-07-24 09:34:15
使用Numpy读取CSV文件,并进行行列删除的操作方法
2023-05-05 03:26:11
MySQL中slave_exec_mode参数详解
2024-01-18 07:36:34
Python使用time模块实现指定时间触发器示例
2022-05-13 02:57:59
vue如何根据权限生成动态路由、导航栏
2024-05-05 09:25:43
在MySQL字段中使用逗号分隔符的方法分享
2024-01-17 23:34:19
Python tkinter 下拉日历控件代码
2023-10-25 06:07:18
Python正确调用 jar 包加密得到加密值的操作方法
2021-02-26 12:12:19
vue3+vite使用jsx和tsx详情
2024-05-10 14:15:47
Mysql动态更新数据库脚本的示例讲解
2024-01-23 11:22:49
python实现七段数码管和倒计时效果
2021-12-22 20:04:01
Python开发装包八种方法详解
2021-01-12 22:27:28
pytorch实现textCNN的具体操作
2022-08-28 17:40:00
在一个网站下再以虚拟目录的方式挂多个网站的方法
2023-07-24 01:03:57
MySQL8.0/8.x忘记密码更改root密码的实战步骤(亲测有效!)
2024-01-27 07:04:26
SQL Server 高性能写入的一些经验总结
2024-01-21 10:46:37
MySQL中where 1=1方法的使用及改进
2024-01-17 22:00:59
解读JavaScript代码 var ie = !-[1,] 最短的IE判定代码
2011-06-06 10:29:00
简单瞅瞅Python vars()内置函数的实现
2021-03-29 20:41:05