Python深度学习之使用Pytorch搭建ShuffleNetv2

作者:I 时间:2023-10-10 06:19:09 

一、model.py

1.1 Channel Shuffle

[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of ``x`` across ``groups`` groups.

    The channel axis is viewed as ``(groups, channels_per_group)``, those two
    axes are swapped, and the result is flattened back into a single channel
    axis. Channel ``g * channels_per_group + j`` therefore ends up at position
    ``j * groups + g``.

    Args:
        x: 4-D tensor unpacked as (batch, channels, height, width).
        groups: number of channel groups. ``channels`` is expected to be
            divisible by it; otherwise the ``view`` call raises.

    Returns:
        A contiguous tensor of the same shape as ``x`` with shuffled channels.
    """
    n, c, h, w = x.size()
    per_group = c // groups

    # (n, c, h, w) -> (n, groups, per_group, h, w), then swap the two channel axes.
    grouped = x.view(n, groups, per_group, h, w)
    swapped = grouped.transpose(1, 2).contiguous()

    # Merge the (per_group, groups) axes back into one channel axis.
    return swapped.view(n, -1, h, w)

1.2 Block(InvertedResidual)

[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]


class InvertedResidual(nn.Module):
    """ShuffleNetV2 building block.

    With ``stride == 1`` the input is split in half along dim 1: the first half
    is passed through unchanged and the second half goes through ``branch2``.
    With ``stride == 2`` the whole input is fed to both ``branch1`` and
    ``branch2``, whose depthwise convolutions use that stride. In both cases
    the two results are concatenated along dim 1 and passed through
    ``channel_shuffle`` with 2 groups.

    Args:
        input_c: number of input channels.
        output_c: number of output channels; must be even.
        stride: 1 or 2.

    Raises:
        ValueError: if ``stride`` is not 1 or 2, if ``output_c`` is odd, or if
            ``stride == 1`` and ``input_c != output_c``.
    """

    def __init__(self, input_c: int, output_c: int, stride: int):
        super(InvertedResidual, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        # These checks guard caller-supplied configuration, so they raise
        # explicitly instead of using `assert` (which is stripped under -O).
        if output_c % 2 != 0:
            raise ValueError(f"output_c must be even, got {output_c}.")
        branch_features = output_c // 2
        # With stride 1 the input is split in half and one half is concatenated
        # back untouched, so input_c must be exactly 2 * branch_features.
        if self.stride == 1 and input_c != branch_features * 2:
            raise ValueError(
                f"with stride 1, input_c ({input_c}) must equal output_c ({output_c})."
            )

        if self.stride == 2:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            # Unused when stride == 1 (forward passes x1 through directly);
            # kept as an empty module so the attribute always exists.
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            # With stride 1 this branch only sees half of the input channels.
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def depthwise_conv(input_c: int,
                       output_c: int,
                       kernel_s: int,
                       stride: int = 1,
                       padding: int = 0,
                       bias: bool = False) -> nn.Conv2d:
        """Build a depthwise ``nn.Conv2d`` (``groups=input_c``)."""
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block and shuffle the concatenated channels (2 groups)."""
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out

1.3 ShuffleNetV2 网络

[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]
[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier assembled from ``inverted_residual`` blocks.

    Layout: conv1 (3x3, stride 2) -> maxpool (3x3, stride 2) -> stage2 ->
    stage3 -> stage4 -> conv5 (1x1) -> mean over dims 2 and 3 -> fc.

    Args:
        stages_repeats: block counts for stage2, stage3 and stage4 (exactly 3
            entries). Each stage's first block uses stride 2; the remaining
            ``repeats - 1`` blocks use stride 1.
        stages_out_channels: output channels of conv1, stage2, stage3, stage4
            and conv5 (exactly 5 entries).
        num_classes: output size of the final linear layer.
        inverted_residual: factory called as ``(in_channels, out_channels, stride)``
            to build each block.

    Raises:
        ValueError: if the two lists do not have 3 and 5 entries respectively.
    """

    def __init__(self,
                 stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # Stem: 3 input channels (RGB) -> stem_c channels.
        stem_c = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, stem_c, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_c),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy; the attributes are assigned via setattr below.
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        # zip stops after 3 stages, so only stages_out_channels[1:4] is used here.
        in_c = stem_c
        for stage_idx, repeats, out_c in zip((2, 3, 4), stages_repeats,
                                             self._stage_out_channels[1:]):
            blocks = [inverted_residual(in_c, out_c, 2)]
            blocks.extend(inverted_residual(out_c, out_c, 1) for _ in range(repeats - 1))
            setattr(self, f"stage{stage_idx}", nn.Sequential(*blocks))
            in_c = out_c

        head_c = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_c, head_c, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(head_c),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(head_c, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        features = self.stage4(self.stage3(self.stage2(self.maxpool(self.conv1(x)))))
        pooled = self.conv5(features).mean([2, 3])  # global average over dims 2 and 3
        return self.fc(pooled)

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier logits for ``x``."""
        return self._forward_impl(x)

二、train.py

[图:Python深度学习之使用Pytorch搭建ShuffleNetv2]

来源:https://blog.csdn.net/weixin_43154149/article/details/116267653

标签:Python,Pytorch,ShuffleNetv2
0
投稿

猜你喜欢

  • python 获取字典特定值对应的键的实现

    2022-07-01 19:25:21
  • 一文详解CORS与预检请求

    2024-06-10 16:04:19
  • 使用SQL Server2005扩展函数进行性能优化

    2010-06-07 11:26:00
  • 解决mysql登录错误:'Access denied for user 'root'@'localhost'

    2024-01-22 16:41:20
  • perl产生随机数实现代码

    2023-04-14 05:30:10
  • SQL Server日志清除的两种方法教程简介

    2008-05-04 20:59:00
  • JS实现页面滚动到关闭时的位置与不滚动效果

    2024-04-10 10:47:56
  • JS加载器如何动态加载外部js文件

    2024-04-16 08:47:06
  • 如何理解及使用Python闭包

    2021-12-22 23:50:59
  • 一起来看看python的装饰器代码

    2023-08-07 19:33:20
  • 如何基于Python深度图生成3D点云详解

    2022-03-08 16:41:11
  • Vue中父组件向子组件通信的方法

    2024-04-26 17:37:32
  • Python 转换文本编码实现解析

    2022-07-15 15:58:49
  • python字符串拼接.join()和拆分.split()详解

    2021-11-12 04:09:17
  • 介绍Python中的一些高级编程技巧

    2022-09-22 19:23:15
  • PHP CURL CURLOPT参数说明(curl_setopt)

    2023-11-14 19:06:44
  • python中not not x 与bool(x) 的区别

    2021-04-27 03:50:17
  • Python实现的远程文件自动打包并下载功能示例

    2021-03-25 04:45:51
  • MySQL8.0的WITH查询详情

    2024-01-24 16:43:16
  • Python自然语言处理之词干,词形与最大匹配算法代码详解

    2023-07-23 04:48:37
  • asp之家 网络编程 m.aspxhome.com