PyTorch CNN 识别手写数字:使用自建图片数据集
作者:瓦力冫 时间:2023-04-18 02:39:22
本文主要介绍如何用 PyTorch 搭建 CNN 识别手写数字,并通过自定义 Dataset 从自建的图片数据(每行为"图片路径 标签"的 txt 索引文件)中加载训练和测试样本,分享给大家,具体如下:
# library
# standard library
import os
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# Uncomment the next line to make runs reproducible.
# torch.manual_seed(1)

# Hyper-parameters
EPOCH = 1  # full passes over the training data; kept at 1 to save time
BATCH_SIZE = 50  # samples per mini-batch
LR = 1e-3  # learning rate
root = "./mnist/raw/"  # prefix prepended to the index file names (train.txt / test.txt)
def default_loader(path):
    """Open the image file at ``path`` with PIL and return the image object.

    No mode conversion is applied; the image is returned exactly as
    ``Image.open`` produced it.
    """
    image = Image.open(path)
    return image
class MyDataset(Dataset):
    """Dataset of grayscale images listed in a plain-text index file.

    Every non-blank line of the index file holds an image path followed by
    an integer class label, separated by whitespace. Only the first two
    whitespace-separated fields of a line are used.

    Args:
        txt: Path of the index file.
        transform: Optional callable applied to the (grayscale) PIL image.
        target_transform: Optional callable applied to the integer label.
        loader: Callable that maps an image path to an image; defaults to
            ``default_loader``.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        # The context manager closes the file even if a line fails to parse.
        with open(txt, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    # Blank lines (e.g. a trailing newline at the end of the
                    # file) carry no sample; skip them instead of crashing.
                    continue
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        fn, label = self.imgs[index]
        img = self.loader(fn)
        # A custom loader may return an array instead of a PIL image.
        if not isinstance(img, Image.Image):
            img = Image.fromarray(np.asarray(img))
        # convert('L') produces a single-channel 8-bit image whatever the
        # source mode is, so RGB or palette files no longer fail here.
        img = img.convert('L')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples read from the index file."""
        return len(self.imgs)
# Training split: every image goes through ToTensor and the loader shuffles
# the samples.
train_data = MyDataset(
    txt=root + 'train.txt',
    transform=torchvision.transforms.ToTensor(),
)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Test split: same preprocessing, no shuffling requested.
test_data = MyDataset(
    txt=root + 'test.txt',
    transform=torchvision.transforms.ToTensor(),
)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)
class CNN(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by a linear classifier.

    The shape notes below assume an input batch of shape (N, 1, 28, 28); the
    linear layer's ``32 * 7 * 7`` input size only fits that spatial size.
    """

    def __init__(self):
        super(CNN, self).__init__()

        # Stage 1: (1, 28, 28) -> conv -> (16, 28, 28) -> pool -> (16, 14, 14).
        # 1 input channel, 16 filters; a 5x5 kernel with stride 1 and
        # padding 2 (= (kernel_size - 1) / 2) keeps width and height.
        first_conv = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2)
        self.conv1 = nn.Sequential(
            first_conv,
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # max over each 2x2 window
        )

        # Stage 2: (16, 14, 14) -> conv -> (32, 14, 14) -> pool -> (32, 7, 7).
        second_conv = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Sequential(
            second_conv,
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Fully connected layer producing 10 class scores.
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return ``(scores, flat_features)`` for the input batch ``x``.

        ``flat_features`` is the flattened output of the second stage,
        shape (batch_size, 32 * 7 * 7); it is returned for visualization.
        """
        features = self.conv2(self.conv1(x))
        flat = features.view(features.size(0), -1)
        return self.out(flat), flat
def _evaluate(model, loader, criterion):
    """Return ``(mean per-sample loss, accuracy)`` of ``model`` over ``loader``.

    ``model`` must return a tuple whose first element holds the class scores.
    ``criterion`` is assumed to return the mean loss of the batch (the
    default reduction of ``nn.CrossEntropyLoss``). The model's previous
    train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.no_grad():  # evaluation needs no autograd graph
        for inputs, targets in loader:
            scores = model(inputs)[0]
            batch = targets.size(0)
            # Weight the batch mean by the batch size so that a smaller
            # final batch is not over-counted in the overall mean.
            total_loss += criterion(scores, targets).item() * batch
            total_correct += (scores.argmax(dim=1) == targets).sum().item()
            total_seen += batch
    model.train(was_training)
    if total_seen == 0:
        # Empty loader: report zeros rather than dividing by zero.
        return 0.0, 0.0
    return total_loss / total_seen, float(total_correct) / total_seen


cnn = CNN()
print(cnn)  # net architecture

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()  # targets are class indices, not one-hot vectors

# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)[0]  # class scores; the flattened features are ignored
        loss = loss_func(output, b_y)
        optimizer.zero_grad()  # clear gradients left from the previous step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients

        if step % 50 == 0:
            # Report performance on the whole test set every 50 training steps.
            test_loss, acc_rate = _evaluate(cnn, test_loader, loss_func)
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(test_loss, acc_rate))
训练和测试所用的图片及 label 索引文件(train.txt / test.txt)的生成方法,见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》
结果如下:
来源:http://www.waitingfy.com/archives/3549
标签:pytorch,cnn,识别手写
0
投稿
猜你喜欢
MySQL8新特性之全局参数持久化详解
2024-01-23 12:17:10
MySql版本问题sql_mode=only_full_group_by的完美解决方案
2024-01-18 16:08:14
Python使用Dash开发网页应用的方法详解
2022-04-26 03:45:14
WEB打印大全
2023-06-30 14:35:15
Python的Flask框架及Nginx实现静态文件访问限制功能
2023-08-13 03:13:38
python数据可视化JupyterLab实用扩展程序Mito
2021-01-24 13:42:04
YUI Compressor快速使用指南
2011-06-27 20:07:30
Vue.js中轻松解决v-for执行出错的三个方案
2024-05-10 14:19:08
分享十个Python超级好用提高工作效率的自动化脚本
2021-06-26 17:17:16
MySQL查询出现1055错误的原因及解决方法
2024-01-13 04:05:54
使用Python的Django框架结合jQuery实现AJAX购物车页面
2023-05-21 01:59:28
JavaScript Alert通用美化类
2024-04-27 15:20:50
Python如何使用OS模块调用cmd
2023-03-22 02:25:39
MySQL函数与存储过程字符串长度限制的解决
2024-01-16 13:17:01
python中字符串的操作方法大全
2023-10-01 17:47:15
详解python的字典及相关操作
2023-03-28 08:52:42
Django模型修改及数据迁移实现解析
2022-05-20 10:20:40
Python调用graphviz绘制结构化图形网络示例
2021-09-15 19:06:49
ASPImage组件的实现过程[图]
2008-02-03 15:37:00
Python中的闭包总结
2023-09-09 03:46:05