pytorch 利用lstm做mnist手写数字识别分类的实例
作者:xckkcxxck 发布时间:2023-01-31 03:15:38
标签:pytorch,lstm,mnist,手写,识别
代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
import sys
sys.path.append('..')
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
#定义数据
data_tf = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
#定义模型
class rnn_classify(nn.Module):
def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
super(rnn_classify, self).__init__()
self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
def forward(self, x):
#x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
out = self.classifier(out) #得到分类结果
return out
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
#定义训练过程
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)
num_correct = (pred_label == label).sum().item()
return num_correct / total
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
if torch.cuda.is_available():
net = net.cuda()
prev_time = datetime.datetime.now()
for epoch in range(num_epochs):
train_loss = 0
train_acc = 0
net = net.train()
for im, label in train_data:
if torch.cuda.is_available():
im = Variable(im.cuda()) # (bs, 3, h, w)
label = Variable(label.cuda()) # (bs, h, w)
else:
im = Variable(im)
label = Variable(label)
# forward
output = net(im)
loss = criterion(output, label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
train_acc += get_acc(output, label)
cur_time = datetime.datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "Time %02d:%02d:%02d" % (h, m, s)
if valid_data is not None:
valid_loss = 0
valid_acc = 0
net = net.eval()
for im, label in valid_data:
if torch.cuda.is_available():
im = Variable(im.cuda())
label = Variable(label.cuda())
else:
im = Variable(im)
label = Variable(label)
output = net(im)
loss = criterion(output, label)
valid_loss += loss.item()
valid_acc += get_acc(output, label)
epoch_str = (
"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data), valid_loss / len(valid_data),
valid_acc / len(valid_data)))
else:
epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
(epoch, train_loss / len(train_data),
train_acc / len(train_data)))
prev_time = cur_time
print(epoch_str + time_str)
train(net, train_data, test_data, 10, optimizer, criterion)
来源:https://blog.csdn.net/xckkcxxck/article/details/82978942
0
投稿
猜你喜欢
- 项目中使用的日志库是使用python官方库logging封装的,但是居然一直么有设置日志自动滚动,经常会受到告警说哪台机器磁盘空间又满,清理
- 以下是一个类文件,下面的注解是调用类的方法注意:如果系统不支持建立Scripting.FileSystemObject对象,那么数据库压缩功
- 如下所示:#-*- coding: utf-8 -*-import pandas as pdimport numpy as npfrom p
- 概述Rollup, 和 Webpack, Parcel 都是模块打包工具(module bundler tool), 但是侧重点不同, 我们
- 在MySQL中有两种方法1、create table t_name select ...2、create table t_name like
- 生成器是迭代器,同时也并不仅仅是迭代器,不过迭代器之外的用途实在是不多,所以我们可以大声地说:生成器提供了非常方便的自定义迭代器的途径。这是
- 一、基础介绍Go 是静态(编译型)语言,是区别于解释型语言的弱类型语言(静态:类型固定,强类型:不同类型不允许直接运算)例如 python
- 今天来说一下,有些刚刚接触python的朋友,在使用pip install安装python 第三方库的过程中会出现网速很慢,或者是安装下载到
- SQLServer分页方式附带50万数据分页时间[本机访问|已重启SQL服务|无其他程序干扰][非索引排序]环境 WIN7 SQL服务12.
- 实例如下:from win32com.client import Dispatch import win32com.client
- 您想更改网站博客页面上 WordPress 文章的顺序吗?那么您就在正确的地方学习此功能并更改文章的顺序。因此,在本文中,我将向您展示如何通
- 本文介绍了从golang语言中fmt包从标准输入获取数据的Scan系列函数、从io.Reader中获取数据的Fscan系列函数以及从字符串中
- 一个不错的文字放大特性源码。效果图:运行代码框<script for=document event=onmousemove>//
- 在开始之前,我们先来看看uint 与 int 的区别上面是图,下面是源码:package main import ( "fmt&q
- Some readers have asked to me what
- 1.查看binlog是否开启show variables like '%log_bin%';2.查看数据文件存放路径:bin
- 1.什么是getters在介绍state中我们了解到,在Store仓库里,state就是用来存放数据,若是对数据进行处理输出,比如数据要过滤
- with 语句是从 Python 2.5 开始引入的一种与异常处理相关的功能(2.5 版本中要通过 from __future__ impo
- 在select语句中可以使用groupby子句将行划分成较小的组,然后,使用聚组函数返回每一个组的汇总信息,另外,可以使用having子句限
- 边缘在人类视觉和计算机视觉中均起着重要的作用。人类能够仅凭一张背景剪影或一个草图就识别出物体类型和姿态。其中OpenCV提供了许多边缘检测滤