如何使用Pytorch搭建模型

作者:颀周 时间:2022-07-18 10:34:27 

1  模型定义

和TF很像,Pytorch也通过继承父类来搭建模型,同样也是实现两个方法。在TF中是__init__()和call(),在Pytorch中则是__init__()和forward()。功能类似,都分别是初始化模型内部结构和进行推理。其它功能比如计算loss和训练函数,你也可以继承在里面,当然这是可选的。下面搭建一个判别MNIST手写字的Demo,首先给出模型代码:


import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn,optim
from torchsummary import summary
from keras.datasets import mnist
from keras.utils import to_categorical
device = torch.device('cuda') #——————1——————

class ModelTest(nn.Module):
def __init__(self,device):
 super().__init__()
 self.layer1 = nn.Sequential(nn.Flatten(),nn.Linear(28*28,512),nn.ReLU())#——————2——————
 self.layer2 = nn.Sequential(nn.Linear(512,512),nn.ReLU())
 self.layer3 = nn.Sequential(nn.Linear(512,512),nn.ReLU())
 self.layer4 = nn.Sequential(nn.Linear(512,10),nn.Softmax())

self.to(device) #——————3——————
 self.opt = optim.SGD(self.parameters(),lr=0.01)#——————4——————
def forward(self,inputs): #——————5——————
 x = self.layer1(inputs)
 x = self.layer2(x)
 x = self.layer3(x)
 x = self.layer4(x)
 return x
def get_loss(self,true_labels,predicts):
 loss = -true_labels * torch.log(predicts) #——————6——————
 loss = torch.mean(loss)
 return loss
def train(self,imgs,labels):
 predicts = model(imgs)
 loss = self.get_loss(labels,predicts)
 self.opt.zero_grad()#——————7——————
 loss.backward()#——————8——————
 self.opt.step()#——————9——————
model = ModelTest(device)
summary(model,(1,28,28),3,device='cuda') #——————10——————

#1:获取设备,以方便后面的模型与变量进行内存迁移,设备名只有两种:'cuda'和'cpu'。通常是在你有GPU的情况下需要这样显式进行设备的设置,从而在需要时,你可以将变量从主存迁移到显存中。如果没有GPU,不获取也没事,pytorch会默认将参数都保存在主存中。

#2:模型中层的定义,可以使用Sequential将想要统一管理的层集中表示为一层。

#3:在初始化中将模型参数迁移到GPU显存中,加速运算,当然你也可以在需要时在外部执行model.to(device)进行迁移。

#4:定义模型的优化器,和TF不同,pytorch需要在定义时就将需要梯度下降的参数传入,也就是其中的self.parameters(),表示当前模型的所有参数。实际上你不用担心定义优化器和模型参数的顺序问题,因为self.parameters()的输出并不是模型参数的实例,而是整个模型参数对象的指针,所以即使你在定义优化器之后又定义了一个层,它依然能优化到。当然优化器你也可以在外部定义,传入model.parameters()即可。这里定义了一个随机梯度下降。

#5:模型的前向传播,和TF的call()类似,定义好model()所执行的就是这个函数。

#6:我将获取loss的函数集成在了模型中,这里计算的是真实标签和预测标签之间的交叉熵。

#7/8/9:在TF中,参数梯度是保存在梯度带中的,而在pytorch中,参数梯度是各自集成在对应的参数中的,可以使用tensor.grad来查看。每次对loss执行backward(),pytorch都会将参与loss计算的所有可训练参数关于loss的梯度叠加进去(直接相加)。所以如果我们没有叠加梯度的意愿的话,那就要在backward()之前先把之前的梯度删除。又因为我们前面已经把待训练的参数都传入了优化器,所以,对优化器使用zero_grad(),就能把所有待训练参数中已存在的梯度都清零。那么梯度叠加什么时候用到呢?比如批量梯度下降,当内存不够直接计算整个批量的梯度时,我们只能将批量分成一部分一部分来计算,每算一个部分得到loss就backward()一次,从而得到整个批量的梯度。梯度计算好后,再执行优化器的step(),优化器根据可训练参数的梯度对其执行一步优化。

#10:使用torchsummary函数显示模型结构。奇怪为什么不把这个继承在torch里面,要重新安装一个torchsummary库。

2  训练及可视化

接下来使用模型进行训练,因为pytorch自带的MNIST数据集并不好用,所以我使用的是Keras自带的,定义了一个获取数据的生成器。下面是完整的训练及绘图代码(50次迭代记录一次准确率):


import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn,optim
from torchsummary import summary
from keras.datasets import mnist
from keras.utils import to_categorical
device = torch.device('cuda') #——————1——————

class ModelTest(nn.Module):
def __init__(self,device):
 super().__init__()
 self.layer1 = nn.Sequential(nn.Flatten(),nn.Linear(28*28,512),nn.ReLU())#——————2——————
 self.layer2 = nn.Sequential(nn.Linear(512,512),nn.ReLU())
 self.layer3 = nn.Sequential(nn.Linear(512,512),nn.ReLU())
 self.layer4 = nn.Sequential(nn.Linear(512,10),nn.Softmax())

self.to(device) #——————3——————
 self.opt = optim.SGD(self.parameters(),lr=0.01)#——————4——————
def forward(self,inputs): #——————5——————
 x = self.layer1(inputs)
 x = self.layer2(x)
 x = self.layer3(x)
 x = self.layer4(x)
 return x
def get_loss(self,true_labels,predicts):
 loss = -true_labels * torch.log(predicts) #——————6——————
 loss = torch.mean(loss)
 return loss
def train(self,imgs,labels):
 predicts = model(imgs)
 loss = self.get_loss(labels,predicts)
 self.opt.zero_grad()#——————7——————
 loss.backward()#——————8——————
 self.opt.step()#——————9——————
def get_data(device,is_train = True, batch = 1024, num = 10000):
train_data,test_data = mnist.load_data()
if is_train:
 imgs,labels = train_data
else:
 imgs,labels = test_data
imgs = (imgs/255*2-1)[:,np.newaxis,...]
labels = to_categorical(labels,10)
imgs = torch.tensor(imgs,dtype=torch.float32).to(device)
labels = torch.tensor(labels,dtype=torch.float32).to(device)
i = 0
while(True):
 i += batch
 if i > num:
  i = batch
 yield imgs[i-batch:i],labels[i-batch:i]
train_dg = get_data(device, True,batch=4096,num=60000)
test_dg = get_data(device, False,batch=5000,num=10000)

model = ModelTest(device)
summary(model,(1,28,28),11,device='cuda')
ACCs = []
import time
start = time.time()
for j in range(20000):
#训练
imgs,labels = next(train_dg)
model.train(imgs,labels)

#验证
img,label = next(test_dg)
predicts = model(img)
acc = 1 - torch.count_nonzero(torch.argmax(predicts,axis=1) - torch.argmax(label,axis=1))/label.shape[0]
if j % 50 == 0:
 t = time.time() - start
 start = time.time()
 ACCs.append(acc.cpu().numpy())
 print(j,t,'ACC: ',acc)
#绘图
x = np.linspace(0,len(ACCs),len(ACCs))
plt.plot(x,ACCs)

准确率变化图如下:

如何使用Pytorch搭建模型

3   注意事项

需要注意的是,pytorch的tensor基于numpy的array,它们是共享内存的。也就是说,如果你把tensor直接插入一个列表,当你修改这个tensor时,列表中的这个tensor也会被修改;更容易被忽略的是,即使你用tensor.detach.numpy(),先将tensor转换为array类型,再插入列表,当你修改原本的tensor时,列表中的这个array也依然会被修改。所以如果我们只是想保存tensor的值而不是整个对象,就要使用np.array(tensor)将tensor的值复制出来。

来源:https://www.cnblogs.com/qizhou/p/13870937.html?utm_source=tuicool&utm_medium=referral

标签:Pytorch,搭建,模型
0
投稿

猜你喜欢

  • Web开发与JavaScript编辑利器——Aptana Studio简介

    2008-05-05 13:32:00
  • python实现学生信息管理系统(面向对象)

    2022-06-13 16:16:07
  • SQL Server中处理空值时涉及的三问题

    2009-01-20 11:24:00
  • SQL处理多级分类,查询结果呈树形结构

    2012-08-21 10:50:12
  • python中使用xlrd、xlwt操作excel表格详解

    2023-06-25 03:59:51
  • python记录程序运行时间的三种方法

    2023-08-25 03:12:19
  • PHP实现根据数组某个键值大小进行排序的方法

    2023-11-15 00:35:55
  • Python 提速器numba

    2021-04-27 21:14:06
  • Oracle平台应用数据库系统的设计与开发

    2010-07-21 13:03:00
  • sql如何在线创建新表?

    2010-06-22 21:21:00
  • pytorch 图像中的数据预处理和批标准化实例

    2023-07-16 15:08:12
  • Python数据分析基础之文件的读取

    2022-10-16 21:25:21
  • dl+ol应用

    2008-06-21 17:04:00
  • FF和IE之间7个JavaScript的差异[译]

    2009-05-04 18:19:00
  • 提高CSS代码的可读性

    2008-05-11 18:59:00
  • python 以16进制打印输出的方法

    2023-10-23 07:33:17
  • 解析PHP中一些可能会被忽略的问题

    2023-09-05 14:07:37
  • 200 行python 代码实现 2048 游戏

    2021-08-06 16:17:35
  • 分支任务:从哪里来,回哪里去

    2009-09-04 18:58:00
  • 使用apidoc管理RESTful风格Flask项目接口文档方法

    2022-11-24 10:05:14
  • asp之家 网络编程 m.aspxhome.com