利用Pytorch实现ResNet网络构建及模型训练

作者:实力 时间:2022-02-24 19:57:59 

构建网络

ResNet由一系列堆叠的残差块组成,其主要作用是通过无限制地增加网络深度,从而使其更加强大。在建立ResNet模型之前,让我们先定义4个层,每个层由多个残差块组成。这些层的目的是降低空间尺寸,同时增加通道数量。

以ResNet50为例,我们可以使用以下代码来定义ResNet网络:

class ResNet(nn.Module):
   def __init__(self, num_classes=1000):
       super().__init__()
       self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
       self.bn1 = nn.BatchNorm2d(64)
       self.relu = nn.ReLU(inplace
(续)
即模型需要在输入层加入一些 normalization 和激活层。
```python
import torch.nn.init as init
class Flatten(nn.Module):
   def __init__(self):
       super().__init__()
   def forward(self, x):
       return x.view(x.size(0), -1)
class ResNet(nn.Module):
   def __init__(self, num_classes=1000):
       super().__init__()
       self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
       self.bn1 = nn.BatchNorm2d(64)
       self.relu = nn.ReLU(inplace=True)
       self.layer1 = nn.Sequential(
           ResidualBlock(64, 256, stride=1),
           *[ResidualBlock(256, 256) for _ in range(1, 3)]
       )
       self.layer2 = nn.Sequential(
           ResidualBlock(256, 512, stride=2),
           *[ResidualBlock(512, 512) for _ in range(1, 4)]
       )
       self.layer3 = nn.Sequential(
           ResidualBlock(512, 1024, stride=2),
           *[ResidualBlock(1024, 1024) for _ in range(1, 6)]
       )
       self.layer4 = nn.Sequential(
           ResidualBlock(1024, 2048, stride=2),
           *[ResidualBlock(2048, 2048) for _ in range(1, 3)]
       )
       self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
       self.flatten = Flatten()
       self.fc = nn.Linear(2048, num_classes)
       for m in self.modules():
           if isinstance(m, nn.Conv2d):
               init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
           elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
               init.constant_(m.weight, 1)
               init.constant_(m.bias, 0)
   def forward(self, x):
       x = self.conv1(x)
       x = self.bn1(x)
       x = self.relu(x)
       x = self.layer1(x)
       x = self.layer2(x)
       x = self.layer3(x)
       x = self.layer4(x)
       x = self.avgpool(x)
       x = self.flatten(x)
       x = self.fc(x)
       return x

改进点如下:

  • 我们使用nn.Sequential组件,将多个残差块组合成一个功能块(layer)。这样可以方便地修改网络深度,并将其与其他层分离九更容易上手,例如迁移学习中重新训练顶部分类器时。

  • 我们在ResNet的输出层添加了标准化和激活函数。它们有助于提高模型的收敛速度并改善性能。

  • 对于nn.Conv2d和批标准化层等神经网络组件,我们使用了PyTorch中的内置初始化函数。它们会自动为我们设置好每层的参数。

  • 我们还添加了一个Flatten层,将4维输出展平为2维张量,以便通过接下来的全连接层进行分类。

训练模型

我们现在已经实现了ResNet50模型,接下来我们将解释如何训练和测试该模型。

首先我们需要定义损失函数和优化器。在这里,我们使用交叉熵损失函数,以及Adam优化器。

import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(num_classes=1000).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

在使用PyTorch进行训练时,我们通常会创建一个循环,为每个批次的输入数据计算损失并对模型参数进行更新。以下是该循环的代码:

def train(model, optimizer, criterion, train_loader, device):
   model.train()
   train_loss = 0
   correct = 0
   total = 0
   for batch_idx, (inputs, targets) in enumerate(train_loader):
       inputs, targets = inputs.to(device), targets.to(device)
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = criterion(outputs, targets)
       loss.backward()
       optimizer.step()
       train_loss += loss.item()
       _, predicted = outputs.max(1)
       total += targets.size(0)
       correct += predicted.eq(targets).sum().item()
   acc = 100 * correct / total
   avg_loss = train_loss / len(train_loader)
   return acc, avg_loss

在上面的训练循环中,我们首先通过model.train()代表进入训练模式。然后使用optimizer.zero_grad()清除

来源:https://juejin.cn/post/7222862599851475001

标签:Pytorch,ResNet,构建网络,模型训练
0
投稿

猜你喜欢

  • python超详细实现完整学生成绩管理系统

    2022-08-25 08:59:08
  • python 递归调用返回None的问题及解决方法

    2022-12-21 05:52:56
  • Python+Tkinter制作猜灯谜小游戏

    2021-09-24 19:43:17
  • python unicodedata模块用法

    2021-04-05 20:53:55
  • Pytorch转onnx、torchscript方式

    2022-05-03 11:10:43
  • asp无限级分类加js收缩伸展功能代码

    2009-12-08 12:25:00
  • 随机抽取的sql语句 每班任意抽取3名学生

    2024-01-27 10:12:26
  • 使用pyqt5 实现ComboBox的鼠标点击触发事件

    2022-01-12 17:24:57
  • asp如何让计数器只对新进用户计数?

    2010-05-13 16:36:00
  • Python搜索引擎实现原理和方法

    2023-06-26 05:35:32
  • Python字典和列表性能之间的比较

    2022-08-08 12:49:58
  • python下MySQLdb用法实例分析

    2024-01-18 11:50:27
  • show一下刚做的系统登录界面

    2008-09-13 19:13:00
  • python根据txt文本批量创建文件夹

    2021-12-18 21:24:52
  • python使用pip安装SciPy、SymPy、matplotlib教程

    2022-03-05 01:46:12
  • JS基于封装函数实现的表格分页完整示例

    2024-04-25 13:15:12
  • Vue 全部生命周期组件梳理整理

    2023-07-02 16:32:44
  • Python中dumps与dump及loads与load的区别

    2021-10-01 09:13:20
  • 实现Windows下设置定时任务来运行python脚本

    2021-10-12 05:03:32
  • ES6新特性一: let和const命令详解

    2024-05-22 10:37:14
  • asp之家 网络编程 m.aspxhome.com