简单易懂的 PyTorch 实战实例:VGG 深度网络

作者:青盏 时间:2021-09-07 19:47:24 

模型:VGG;数据集:CIFAR-10。对照这份代码走一遍,大致就能了解 PyTorch 的整体运行机制。

来源:见文末链接。

定义模型(放在 models 包中,训练脚本通过 `from models import *` 导入):


'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable

# Layer layouts for each VGG variant: an int is the output-channel count of a
# 3x3 conv (+ BatchNorm + ReLU) layer, 'M' is a 2x2 max-pool with stride 2.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


# Models must subclass nn.Module.
class VGG(nn.Module):
    '''VGG network for 32x32 CIFAR images: a conv feature stack plus one linear classifier.

    Args:
        vgg_name: Key into ``cfg`` ('VGG11', 'VGG13', 'VGG16' or 'VGG19').
        num_classes: Number of output logits; defaults to 10 (CIFAR-10).
    '''

    def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        # Every layout ends with 512 channels after five 2x2 max-pools, so a
        # 32x32 input is reduced to 512x1x1 features.
        self.classifier = nn.Linear(512, num_classes)

    # Forward pass: defines how the output is computed from the input.
    def forward(self, x):
        '''Return class logits of shape (N, num_classes) for the batch ``x``.

        ``x`` is expected to be (N, 3, 32, 32). Other spatial sizes change the
        flattened feature size, which then no longer matches the classifier.
        '''
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten to (N, 512)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        '''Build the convolutional feature extractor described by ``cfg``.'''
        layers = []
        in_channels = 3  # RGB input
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # A 1x1 average pool with stride 1 is an identity op; kept from the original layout.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())

定义训练过程(训练脚本 main.py):


'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar
from torch.autograd import Variable

# Command-line options.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', type=float, default=0.1,
                    help='learning rate')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
args = parser.parse_args()

# Run-wide state: device availability, best test accuracy so far, and the
# first epoch to run (0, or the epoch restored from a checkpoint).
use_cuda = torch.cuda.is_available()
best_acc = 0
start_epoch = 0

# Load CIFAR-10 and define the preprocessing pipelines.
print('==> Preparing data..')

# Training: augment with a random 32x32 crop from the image padded by 4 pixels
# and a random horizontal flip, then convert to a tensor and normalize per channel.
# NOTE(review): the std values below are smaller than a pooled per-pixel CIFAR-10
# std would be; they may come from averaging per-image stds — confirm intent.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Resume training from a checkpoint, or build a fresh model.
if args.resume:
    print('==> Resuming from checkpoint..')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isdir('checkpoint'):
        raise SystemExit('Error: no checkpoint directory found!')
    # The checkpoint dict holds a whole pickled nn.Module under 'net', so it must
    # be fully unpickled. PyTorch >= 2.6 defaults to weights_only=True, which
    # rejects it (the keyword itself needs PyTorch >= 1.13). Only resume from
    # checkpoint files you trust.
    # map_location='cpu' lets a checkpoint saved on a GPU be resumed on a
    # CPU-only machine; the model is moved to the GPU below when one is available.
    checkpoint = torch.load('./checkpoint/ckpt.t7', map_location='cpu', weights_only=False)
    net = checkpoint['net']
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
else:
    print('==> Building model..')
    net = VGG('VGG16')
    # Alternative architectures (presumably provided by the `models` package):
    # net = ResNet18()
    # net = PreActResNet18()
    # net = GoogLeNet()
    # net = DenseNet121()
    # net = ResNeXt29_2x64d()
    # net = MobileNet()
    # net = MobileNetV2()
    # net = DPN92()
    # net = ShuffleNetG2()
    # net = SENet18()

# Use the GPU(s) when available.
if use_cuda:
    # Move parameters and buffers to the default GPU.
    net.cuda()
    # Replicate across every visible GPU. Using device_count() - 1 here would
    # skip the last GPU, and on a single-GPU machine it would pass an empty
    # device list, which makes DataParallel fail.
    net = torch.nn.DataParallel(net, device_ids=list(range(torch.cuda.device_count())))
    # Let cuDNN benchmark convolution algorithms; helps when input shapes stay fixed.
    cudnn.benchmark = True

# Loss and optimizer: cross-entropy loss; SGD with momentum 0.9 and
# L2 weight decay 5e-4, learning rate taken from --lr.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-4,
)

# Training phase.
def train(epoch):
    '''Run one training epoch over ``trainloader`` and print running statistics.

    Uses the module-level ``net``, ``trainloader``, ``optimizer``, ``criterion``
    and ``use_cuda`` defined earlier in this script.

    Args:
        epoch: Epoch index; only used for the printed header.
    '''
    print('\nEpoch: %d' % epoch)
    # Switch to training mode (BatchNorm layers use batch statistics).
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        # Tensors take part in autograd directly; no Variable wrapper is needed.
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        # Backpropagate, then update the parameters.
        loss.backward()
        optimizer.step()

        # .item() gives a plain Python float detached from the autograd graph;
        # summing the loss tensor itself would keep every batch's graph alive.
        # The older ``loss.data[0]`` idiom raises on 0-dim tensors in recent PyTorch.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        # Keep `correct` a Python int so the accuracy below is float division.
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

# Evaluation phase.
def test(epoch):
    '''Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    Uses the module-level ``net``, ``testloader``, ``criterion`` and ``use_cuda``.
    Updates the global ``best_acc``. When accuracy beats it, saves
    ``{'net', 'acc', 'epoch'}`` to ./checkpoint/ckpt.t7.

    Args:
        epoch: Epoch index stored in the checkpoint.
    '''
    global best_acc
    # Switch to evaluation mode (BatchNorm uses running statistics).
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    # no_grad replaces the removed ``volatile=True`` flag: no autograd graph is
    # built during evaluation, which saves memory and time.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            # Accumulate Python numbers, not tensors.
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))

    # Save a checkpoint when test accuracy improves.
    acc = 100. * correct / total
    if acc > best_acc:
        print('Saving..')
        # Store the bare model rather than its DataParallel wrapper.
        model = net.module if isinstance(net, torch.nn.DataParallel) else net
        state = {
            'net': model,
            'acc': acc,
            'epoch': epoch,
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc

# Main loop: 200 epochs, numbered from start_epoch (0, or the resumed epoch).
# Each epoch trains once, then evaluates (and possibly checkpoints).
for epoch in range(start_epoch, start_epoch + 200):
    train(epoch)
    test(epoch)
    # Release unused cached blocks held by PyTorch's CUDA allocator.
    torch.cuda.empty_cache()

运行:

从头训练新模型:
python main.py --lr=0.01
从检查点继续训练已保存的模型:
python main.py --resume --lr=0.01

一些工具函数(utils.py,训练脚本通过 `from utils import progress_bar` 导入):


'''Some helper functions for PyTorch, including:
 - get_mean_and_std: calculate the mean and std value of dataset.
 - msr_init: net parameter initialization.
 - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init

def get_mean_and_std(dataset, batch_size=1, num_workers=2):
    '''Compute the per-channel mean and standard deviation of ``dataset``.

    Statistics are pooled over every pixel of every image. The std therefore
    includes variation between images, not just the average within-image
    spread. The std is the population std (no Bessel correction).

    Args:
        dataset: Map-style dataset yielding ``(image, target)`` pairs, where
            ``image`` is a tensor of shape (C, H, W).
        batch_size: Images per DataLoader batch. Values > 1 are faster but
            require all images to have the same size.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(mean, std)``: float32 tensors of shape (C,).

    Raises:
        ValueError: If ``dataset`` yields no images.
    '''
    # Local import: this utils module does not import torch at file level.
    import torch
    from torch.utils.data import DataLoader

    # Sample order does not affect the statistics, so no shuffling is needed.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        # Accumulate in float64 to limit rounding error over many pixels.
        inputs = inputs.double()
        if channel_sum is None:
            channel_sum = torch.zeros(inputs.size(1), dtype=torch.float64)
            channel_sq_sum = torch.zeros(inputs.size(1), dtype=torch.float64)
        channel_sum += inputs.sum(dim=(0, 2, 3))
        channel_sq_sum += inputs.pow(2).sum(dim=(0, 2, 3))
        pixel_count += inputs.size(0) * inputs.size(2) * inputs.size(3)
    if pixel_count == 0:
        raise ValueError('dataset is empty')
    mean = channel_sum / pixel_count
    # Var = E[x^2] - E[x]^2; clamp guards against tiny negative rounding errors.
    std = (channel_sq_sum / pixel_count - mean.pow(2)).clamp(min=0).sqrt()
    return mean.float(), std.float()

def init_params(net):
    '''Initialize the parameters of every Conv2d, BatchNorm2d and Linear layer in ``net``, in place.

    - Conv2d: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    - BatchNorm2d: weight 1, bias 0 (skipped when the layer has no affine params).
    - Linear: normal weights with std 1e-3, zero bias.

    Modules of other types are left untouched.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Compare with None: truth-testing a multi-element tensor raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False BatchNorm layers have weight/bias set to None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)

_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

# Bar width in characters, and the timestamps progress_bar uses for its
# per-step and total timings.
TOTAL_BAR_LENGTH = 65.
begin_time = last_time = time.time()


def progress_bar(current, total, msg=None):
    '''Draw a one-line text progress bar for step ``current`` (0-based) of ``total``.

    The line shows the bar, the time since the previous call ("Step"), the
    time since the bar started at ``current == 0`` ("Tot") and the optional
    ``msg``. The cursor is then moved back over the bar and the
    ``current+1/total`` counter is printed there. The line ends with a carriage
    return, so the next call overwrites it, except on the last step, which ends
    with a newline.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # a new bar starts: reset the total timer

    filled = int(TOTAL_BAR_LENGTH * current / total)
    remaining = int(TOTAL_BAR_LENGTH - filled) - 1
    out = sys.stdout
    out.write(' [' + '=' * filled + '>' + '.' * remaining + ']')

    now = time.time()
    step_time = now - last_time
    tot_time = now - begin_time
    last_time = now

    info = ' Step: %s' % format_time(step_time) + ' | Tot: %s' % format_time(tot_time)
    if msg:
        info += ' | ' + msg
    out.write(info)
    # Pad with spaces up to the terminal edge.
    out.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(info) - 3))

    # Backspace to roughly the middle of the bar and print the step counter there.
    out.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2) + 2))
    out.write(' %d/%d ' % (current + 1, total))

    out.write('\r' if current < total - 1 else '\n')
    out.flush()

def format_time(seconds):
    '''Render a duration in seconds as a compact string with at most two units.

    The units are days ('D'), hours ('h'), minutes ('m'), seconds ('s') and
    milliseconds ('ms'). Zero-valued units are skipped, and only the first two
    non-zero units (largest first) are kept, e.g. 3661 -> '1h1m'. Returns
    '0ms' when every unit is zero.
    '''
    days = int(seconds / 3600 / 24)
    seconds -= days * 3600 * 24
    hours = int(seconds / 3600)
    seconds -= hours * 3600
    minutes = int(seconds / 60)
    seconds -= minutes * 60
    whole_secs = int(seconds)
    millis = int((seconds - whole_secs) * 1000)

    units = [(days, 'D'), (hours, 'h'), (minutes, 'm'), (whole_secs, 's'), (millis, 'ms')]
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'

来源:https://blog.csdn.net/qq_16234613/article/details/79818370

标签:Pytorch,VGG,深度网络
0
投稿

猜你喜欢

  • Python图像滤波处理操作示例【基于ImageFilter类】

    2021-10-31 16:47:20
  • 初学者学习Python好还是Java好

    2021-03-16 21:48:32
  • 交互设计模式(三)-Tagging(标签)

    2009-10-19 20:46:00
  • Django前端BootCSS实现分页的方法

    2023-12-21 01:45:34
  • 简单的Vue SSR的示例代码

    2023-07-02 17:08:46
  • JavaScript中实现块作用域的方法

    2024-04-16 10:38:39
  • seaborn绘制双变量联合分布图示例详解

    2021-04-29 01:49:24
  • 详解Python列表赋值复制深拷贝及5种浅拷贝

    2022-07-16 16:22:41
  • python神经网络使用Keras构建RNN训练

    2021-07-19 21:12:15
  • 用Python进行websocket接口测试

    2022-03-02 09:44:22
  • mysql 5.1版本修改密码及远程登录mysql数据库的方法

    2024-01-17 15:38:02
  • python批量修改xml属性的实现方式

    2022-10-03 12:34:58
  • 解析PHP中empty is_null和isset的测试

    2023-11-18 17:39:06
  • JSP EL表达式详细介绍

    2023-07-02 22:32:32
  • PHP PDOStatement::rowCount讲解

    2023-06-06 12:24:04
  • SQL数据类型详解

    2024-01-13 02:01:30
  • Python asyncio异步编程简单实现示例

    2023-09-23 15:27:52
  • 详解nodejs express下使用redis管理session

    2024-05-11 09:51:40
  • 手把手教你使用Python解决简单的zip文件解压密码

    2021-01-20 10:08:41
  • php输出指定时间以前时间格式的方法

    2024-05-09 14:46:39
  • asp之家 网络编程 m.aspxhome.com