PyTorch实现卷积神经网络的搭建详解

作者:Bubbliiiing 时间:2021-03-12 16:36:03 

PyTorch中实现卷积的重要基础函数

1、nn.Conv2d:

nn.Conv2d在pytorch中用于实现卷积。

nn.Conv2d(
   in_channels=32,
   out_channels=64,
   kernel_size=3,
   stride=1,
   padding=1,
)

1、in_channels为输入通道数。

2、out_channels为输出通道数。

3、kernel_size为卷积核大小。

4、stride为步数。

5、padding为padding情况。

6、dilation表示空洞卷积情况。

2、nn.MaxPool2d(kernel_size=2)

nn.MaxPool2d在pytorch中用于实现最大池化。

具体使用方式如下:

MaxPool2d(kernel_size,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode=False)

1、kernel_size为池化核的大小

2、stride为步长

3、padding为填充情况

3、nn.ReLU()

nn.ReLU()用来实现Relu函数,实现非线性。

4、x.view()

x.view用于reshape特征层的形状。

全部代码

这是一个简单的CNN模型,用于预测mnist手写体。

import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循环世代
EPOCH = 20
BATCH_SIZE = 50
# 下载mnist数据集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())                
# (60000)
print(train_data.train_labels.size())              
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 标准化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神经网络
class CNN(nn.Module):
   def __init__(self):
       super(CNN, self).__init__()
       #----------------------------#
       #   第一部分卷积
       #----------------------------#
       self.conv1 = nn.Sequential(
           nn.Conv2d(
               in_channels=1,
               out_channels=32,
               kernel_size=5,
               stride=1,
               padding=2,
               dilation=1
           ),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2),
       )
       #----------------------------#
       #   第二部分卷积
       #----------------------------#
       self.conv2 = nn.Sequential(
           nn.Conv2d(
               in_channels=32,
               out_channels=64,
               kernel_size=3,
               stride=1,
               padding=1,
               dilation=1
           ),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2),
       )
       #----------------------------#
       #   全连接+池化+全连接
       #----------------------------#
       self.ful1 = nn.Linear(64 * 7 * 7, 512)
       self.drop = nn.Dropout(0.5)
       self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
   #----------------------------#
   #   前向传播
   #----------------------------#  
   def forward(self, x):
       x = self.conv1(x)
       x = self.conv2(x)
       x = x.view(x.size(0), -1)
       x = self.ful1(x)
       x = self.drop(x)
       output = self.ful2(x)
       return output
cnn = CNN()
# 指定优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
# 指定loss函数
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
   for step, (b_x, b_y) in enumerate(train_loader):
       #----------------------------#
       #   计算loss并修正权值
       #----------------------------#  
       output = cnn(b_x)
       loss = loss_func(output, b_y)
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       #----------------------------#
       #   打印
       #----------------------------#  
       if step % 50 == 0:
           test_output = cnn(test_x)
           pred_y = torch.max(test_output, 1)[1].data.numpy()
           accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
           print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)

来源:https://blog.csdn.net/weixin_44791964/article/details/103658845

标签:PyTorch,卷积神经网络,神经网络
0
投稿

猜你喜欢

  • 兼容FF的图片切换代码

    2009-09-26 20:15:00
  • Python 选择排序中的树形选择排序

    2023-06-10 04:33:32
  • python使用celery实现异步任务执行的例子

    2021-01-06 03:27:20
  • javascript 改变字体大小方法集合

    2023-07-06 16:58:02
  • python数据分析近年比特币价格涨幅趋势分布

    2022-04-02 15:05:56
  • class和id命名探讨

    2007-10-16 17:55:00
  • python标准库random模块处理随机数

    2023-11-23 16:22:49
  • 搜索系统与导航系统的关系

    2009-09-08 12:44:00
  • asp单主键高效通用分页存储过程

    2009-02-23 19:21:00
  • Python减少循环层次和缩进的技巧分析

    2023-10-07 21:41:09
  • Django与DRF结合的全局异常处理方案详解

    2021-05-19 22:53:16
  • 如何用ASP创建日志文件

    2008-03-10 17:27:00
  • 对python使用telnet实现弱密码登录的方法详解

    2023-12-28 02:52:46
  • Python Numpy教程之排序,搜索和计数详解

    2021-10-31 05:30:21
  • Python制作CSDN免积分下载器

    2021-12-25 03:46:35
  • 在MySQL中使用XML数据—数据格式化

    2009-12-29 10:26:00
  • JavaScript实现带自动提示的文本框效果代码

    2011-02-05 11:13:00
  • pyqt4教程之messagebox使用示例分享

    2023-11-06 08:09:03
  • sublime text配置node.js调试(图文教程)

    2023-07-04 14:07:57
  • python读取csv文件示例(python操作csv)

    2023-02-28 23:12:02
  • asp之家 网络编程 m.aspxhome.com