Pytorch中torch.flatten()和torch.nn.Flatten()实例详解

作者:有人比我慢吗 时间:2021-09-15 06:39:43 

 torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1)代表从第二维开始平坦化。

import torch
x=torch.randn(2,4,2)
print(x)

z=torch.flatten(x)
print(z)

w=torch.flatten(x,1)
print(w)

输出为:
tensor([[[-0.9814,  0.8251],
        [ 0.8197, -1.0426],
        [-0.8185, -1.3367],
        [-0.6293,  0.6714]],

[[-0.5973, -0.0944],
        [ 0.3720,  0.0672],
        [ 0.2681,  1.8025],
        [-0.0606,  0.4855]]])

tensor([-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714,
       -0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855])

tensor([[-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714]
,
       [-0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855]
])

 torch.flatten(x,0,1)代表在第一维和第二维之间平坦化

import torch
x=torch.randn(2,4,2)
print(x)

w=torch.flatten(x,0,1) #第一维长度2,第二维长度为4,平坦化后长度为2*4
print(w.shape)

print(w)

输出为:
tensor([[[-0.5523, -0.1132],
        [-2.2659, -0.0316],
        [ 0.1372, -0.8486],
        [-0.3593, -0.2622]],

[[-0.9130,  1.0038],
        [-0.3996,  0.4934],
        [ 1.7269,  0.8215],
        [ 0.1207, -0.9590]]])

torch.Size([8, 2])

tensor([[-0.5523, -0.1132],
       [-2.2659, -0.0316],
       [ 0.1372, -0.8486],
       [-0.3593, -0.2622],
       [-0.9130,  1.0038],
       [-0.3996,  0.4934],
       [ 1.7269,  0.8215],
       [ 0.1207, -0.9590]])

对于torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。

import torch
#随机32个通道为1的5*5的图
x=torch.randn(32,1,5,5)

model=torch.nn.Sequential(
   #输入通道为1,输出通道为6,3*3的卷积核,步长为1,padding=1
   torch.nn.Conv2d(1,6,3,1,1),
   torch.nn.Flatten()
)
output=model(x)
print(output.shape)  # 6*(7-3+1)*(7-3+1)

输出为:

torch.Size([32, 150])

来源:https://blog.csdn.net/Super_user_and_woner/article/details/120782656

标签:pytorch,torch.flatten(),torch.nn.flatten()
0
投稿

猜你喜欢

  • python修改txt文件中的某一项方法

    2021-02-08 14:26:40
  • Python读取图片为16进制表示简单代码

    2021-07-24 09:34:15
  • 使用Numpy读取CSV文件,并进行行列删除的操作方法

    2023-05-05 03:26:11
  • MySQL中slave_exec_mode参数详解

    2024-01-18 07:36:34
  • Python使用time模块实现指定时间触发器示例

    2022-05-13 02:57:59
  • vue如何根据权限生成动态路由、导航栏

    2024-05-05 09:25:43
  • 在MySQL字段中使用逗号分隔符的方法分享

    2024-01-17 23:34:19
  • Python tkinter 下拉日历控件代码

    2023-10-25 06:07:18
  • Python正确调用 jar 包加密得到加密值的操作方法

    2021-02-26 12:12:19
  • vue3+vite使用jsx和tsx详情

    2024-05-10 14:15:47
  • Mysql动态更新数据库脚本的示例讲解

    2024-01-23 11:22:49
  • python实现七段数码管和倒计时效果

    2021-12-22 20:04:01
  • Python开发装包八种方法详解

    2021-01-12 22:27:28
  • pytorch实现textCNN的具体操作

    2022-08-28 17:40:00
  • 在一个网站下再以虚拟目录的方式挂多个网站的方法

    2023-07-24 01:03:57
  • MySQL8.0/8.x忘记密码更改root密码的实战步骤(亲测有效!)

    2024-01-27 07:04:26
  • SQL Server 高性能写入的一些经验总结

    2024-01-21 10:46:37
  • MySQL中where 1=1方法的使用及改进

    2024-01-17 22:00:59
  • 解读JavaScript代码 var ie = !-[1,] 最短的IE判定代码

    2011-06-06 10:29:00
  • 简单瞅瞅Python vars()内置函数的实现

    2021-03-29 20:41:05
  • asp之家 网络编程 m.aspxhome.com