Pytorch平均池化nn.AvgPool2d()使用方法实例

作者:Cassiel_cx 时间:2023-09-30 02:49:35 

【pytorch官方文档】:https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html?highlight=avgpool2d#torch.nn.AvgPool2d

torch.nn.AvgPool2d()

作用

在由多通道组成的输入特征中进行2D平均池化计算

函数

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

参数

Args:
    kernel_size: 滑窗(池化核)大小
    stride: 滑窗的移动步长, 默认值为kernel_size
    padding: 在输入信号两侧的隐式零填充数量
    ceil_mode: 决定计算输出的形状时是向上取整还是向下取整, 默认为False(向下取整)
    count_include_pad: 在平均池化计算中是否包含零填充, 默认为True(包含零填充)
    divisor_override: 如果指定了, 它将被作为平均池化计算中的除数, 否则将使用池化区域的大小作为平均池化计算的除数

公式

Pytorch平均池化nn.AvgPool2d()使用方法实例

代码实例

假设输入特征为S,输出特征为D

情况一

ceil_mode=False, count_include_pad=True(计算时包含零填充)

import torch
import torch.nn as nn
import numpy as np

# 生成一个形状为1*1*3*3的张量
x1 = np.array([
             [1,2,3],
             [4,5,6],
             [7,8,9]
           ])
x1 = torch.from_numpy(x1).float()
x1 = x1.unsqueeze(0).unsqueeze(0)

# 实例化二维平均池化
avgpool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=True)
y1 = avgpool1(x1)
print(y1)

# 打印结果
'''
tensor([[[[1.3333, 1.7778],
         [2.6667, 3.1111]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (0+0+0+0+1+2+0+4+5) / 9 = 1.3333,

D[1,2] = (0+0+0+2+3+0+5+6+0) / 9 = 1.7778,

D[2,1] = (0+4+5+0+7+8+0+0+0) / 9 = 2.6667,

D[2,2] = (5+6+0+8+9+0+0+0+0) / 9 = 3.1111.

情况二

ceil_mode=False, count_include_pad=False(计算时不包含零填充)

avgpool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False)

y2 = avgpool2(x1)
print(y2)

# 打印结果
'''
tensor([[[[3., 4.],
         [6., 7.]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (1+2+4+5) / 4 = 3,

D[1,2] = (2+3+5+6) / 4 = 4,

D[2,1] = (4+5+7+8) / 4 = 6,

D[2,2] = (5+6+8+9) / 4 = 7.

情况三

ceil_mode=False, count_include_pad=False, divisor_override=2(将计算平均池化时的除数指定为2)

avgpool3 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False, divisor_override=2)

y3 = avgpool3(x1)
print(y3)

# 打印结果
'''
tensor([[[[ 6.,  8.],
         [12., 14.]]]])
'''

计算过程:

输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,

D[1,1] = (1+2+4+5) / 2 = 6,

D[1,2] = (2+3+5+6) / 2 = 8,

D[2,1] = (4+5+7+8) / 2 = 12,

D[2,2] = (5+6+8+9) / 2 = 14.

情况四

ceil_mode=True, count_include_pad=True, divisor_override=None(在计算输出的形状时向上取整)

x2 = np.array([
             [1,2,3,4],
             [5,6,7,8],
             [9,10,11,12],
             [13,14,15,16]
             ])
x2 = torch.from_numpy(x2).reshape(1,1,4,4).float()
avgpool4 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
y4 = avgpool4(x2)
print(y4)

# 打印结果
'''
tensor([[[[ 1.5556,  3.3333,  2.0000],
         [ 6.3333, 11.0000,  6.0000],
         [ 4.5000,  7.5000,  4.0000]]]])
'''

计算过程:

输出形状 = ceil[(4 - 3 + 2) / 2] + 1 = 3,

D[1,1] = (0+0+0+0+1+2+0+5+6) / 9 = 1.5556,

D[1,2] = (0+0+0+2+3+4+6+7+8) / 9 = 3.3333,

Pytorch平均池化nn.AvgPool2d()使用方法实例

D[1,3] = (0+0+4+0+8+0) / 6 = 2,

D[2,1] = (0+5+6+0+9+10+0+13+14) / 9 = 6.3333,

D[2,2] = (6+7+8+10+11+12+14+15+16) / 9 = 11,

Pytorch平均池化nn.AvgPool2d()使用方法实例

D[2,3] = (8+0+12+0+16+0) / 6 = 6,

Pytorch平均池化nn.AvgPool2d()使用方法实例

D[3,1] = (0+13+14+0+0+0) / 6 = 4.5,

D[3,2] = (14+15+16+0+0+0) / 6 = 7.5,

Pytorch平均池化nn.AvgPool2d()使用方法实例

D[3,3] = (16+0+0+0) / 4 = 4.

来源:https://blog.csdn.net/qq_38964360/article/details/129148451

标签:pytorch,平均池化,nn.avgpool2d()
0
投稿

猜你喜欢

  • python利用datetime模块计算时间差

    2021-10-07 01:02:04
  • jQuery中$.get、$.post、$.getJSON和$.ajax的用法详解

    2024-04-16 08:54:20
  • python代码实现TSNE降维数据可视化教程

    2023-09-08 16:50:37
  • vue中子组件调用兄弟组件方法

    2024-04-30 10:24:44
  • JS实现点击表头表格自动排序(含数字、字符串、日期)

    2024-05-02 16:16:53
  • JavaScript观察者模式原理与用法实例详解

    2024-04-19 10:02:48
  • Python一行代码快速实现程序进度条示例

    2022-07-07 07:22:26
  • 我的css样式写法总结

    2009-01-18 13:04:00
  • 20行Python代码实现视频字符化功能

    2023-01-08 21:17:02
  • Python的字典和列表的使用中一些需要注意的地方

    2023-01-09 02:03:52
  • PDO::setAttribute讲解

    2023-06-05 18:04:23
  • 十万条Access数据表分页的两个解决方法

    2008-05-23 18:24:00
  • 浅谈keras中的目标函数和优化函数MSE用法

    2022-01-19 02:15:55
  • MySQL中连接查询和子查询的问题

    2024-01-19 04:27:32
  • golang 监听服务的信号,实现平滑启动,linux信号说明详解

    2024-05-09 10:00:43
  • 人性化网页设计技巧

    2007-10-15 13:02:00
  • Python将8位的图片转为24位的图片实现方法

    2021-07-31 12:11:42
  • 模仿MSN消息提示的效果

    2013-07-02 06:22:28
  • 网页设计标准尺寸

    2008-06-15 15:21:00
  • 10款最好的Web开发的 Python 框架

    2023-12-21 11:26:04
  • asp之家 网络编程 m.aspxhome.com