pytorch SENet实现案例

作者:小伟db 时间:2021-03-27 05:14:23 

我就废话不多说了,大家还是直接看代码吧~


from torch import nn

class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
 super(SELayer, self).__init__()

//返回1X1大小的特征图,通道数不变
 self.avg_pool = nn.AdaptiveAvgPool2d(1)
 self.fc = nn.Sequential(
  nn.Linear(channel, channel // reduction, bias=False),
  nn.ReLU(inplace=True),
  nn.Linear(channel // reduction, channel, bias=False),
  nn.Sigmoid()
 )

def forward(self, x):
 b, c, _, _ = x.size()

//全局平均池化,batch和channel和原来一样保持不变
 y = self.avg_pool(x).view(b, c)

//全连接层+池化
 y = self.fc(y).view(b, c, 1, 1)

//和原特征图相乘
 return x * y.expand_as(x)

补充知识:pytorch 实现 SE Block

论文模块图

pytorch SENet实现案例

代码


import torch.nn as nn
class SE_Block(nn.Module):
def __init__(self, ch_in, reduction=16):
 super(SE_Block, self).__init__()
 self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全局自适应池化
 self.fc = nn.Sequential(
  nn.Linear(ch_in, ch_in // reduction, bias=False),
  nn.ReLU(inplace=True),
  nn.Linear(ch_in // reduction, ch_in, bias=False),
  nn.Sigmoid()
 )

def forward(self, x):
 b, c, _, _ = x.size()
 y = self.avg_pool(x).view(b, c)
 y = self.fc(y).view(b, c, 1, 1)
 return x * y.expand_as(x)

现在还有许多关于SE的变形,但大都大同小异

来源:https://blog.csdn.net/qq_35985044/article/details/90142431

标签:pytorch,SENet
0
投稿

猜你喜欢

  • Python使用20行代码实现微信聊天机器人

    2023-12-04 12:52:06
  • 一个简单的python爬虫程序 爬取豆瓣热度Top100以内的电影信息

    2023-01-09 19:50:16
  • Python matplotlib生成图片背景透明的示例代码

    2022-07-04 06:22:57
  • 快速解决Django关闭Debug模式无法加载media图片与static静态文件

    2023-05-28 02:54:43
  • sp_delete_backuphistory

    2008-06-07 13:59:00
  • tensorflow之并行读入数据详解

    2021-09-20 14:42:30
  • MySQL中order by的执行过程

    2024-01-15 00:29:16
  • 5款Python程序员高频使用开发工具推荐

    2022-01-25 14:09:16
  • oracle数据库创建备份与恢复脚本整理

    2023-07-13 00:57:20
  • 使用python进行图片的文字识别详细代码

    2021-06-27 07:01:06
  • Python标准库calendar的使用方法

    2023-09-01 01:28:07
  • python五子棋游戏的设计与实现

    2021-04-19 07:17:43
  • Django rest framework分页接口实现原理解析

    2023-08-22 21:26:36
  • 用Python制作一个可以聊天的皮卡丘版桌面宠物

    2021-12-05 10:44:41
  • Python采用raw_input读取输入值的方法

    2021-09-08 03:19:05
  • OpenCV-Python使用cv2实现傅里叶变换

    2023-07-08 05:11:06
  • Go语言中map使用和并发安全详解

    2024-04-26 17:21:00
  • 使用client-go工具调用kubernetes API接口的教程详解(v1.17版本)

    2024-05-08 10:52:08
  • vue+element项目中过滤输入框特殊字符小结

    2024-04-28 10:53:44
  • matlab和Excel的数据交互操作(非xlsread和xlswrite)

    2022-06-16 01:00:42
  • asp之家 网络编程 m.aspxhome.com