pytorch cnn 识别手写的字实现自建图片数据

作者:瓦力冫 时间:2023-04-18 02:39:22 

本文主要介绍如何用 PyTorch 构建 CNN,并在自建的图片数据集(由 MNIST 转换得到的图片和 txt 标签文件)上训练,以识别手写数字,分享给大家,具体如下:


# library
# standard library
import os
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible
# Hyper Parameters
EPOCH = 1          # number of passes over the training set (kept at 1 to save time)
BATCH_SIZE = 50    # batch size handed to both DataLoaders
LR = 1e-3          # learning rate handed to the optimizer

# Directory prefix that is prepended to the train/test index file names.
root = "./mnist/raw/"

def default_loader(path):
    """Open the image file at ``path`` with PIL and return it as-is.

    No mode conversion (e.g. to RGB) is applied here.
    """
    image = Image.open(path)
    return image

class MyDataset(Dataset):
    """Image dataset described by a plain-text index file.

    Each non-blank line of the index file holds whitespace-separated fields;
    the first is the image path and the second is its integer label.

    Args:
        txt: Path of the index file.
        transform: Optional callable applied to the image before it is returned.
        target_transform: Optional callable applied to the label before it is
            returned.
        loader: Callable that takes an image path and returns the image.

    Raises:
        ValueError: If a label field cannot be parsed as an integer.
        IndexError: If a non-blank line has fewer than two fields.
    """

    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        imgs = []
        # `with` closes the index file even when a line fails to parse.
        with open(txt, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                imgs.append((words[0], int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        fn, label = self.imgs[index]
        img = self.loader(fn)
        # Rebuild the image from its pixel array as mode 'L'.
        # NOTE(review): assumes the loader yields a single-channel 8-bit image - confirm.
        img = Image.fromarray(np.array(img), mode='L')
        if self.transform is not None:
            img = self.transform(img)
        # Previously stored but never applied; honour it when provided.
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.imgs)

# Build the train/test datasets from the index files under `root`.
# os.path.join keeps the paths valid whether or not `root` ends with a separator.
train_data = MyDataset(
    txt=os.path.join(root, 'train.txt'),
    transform=torchvision.transforms.ToTensor(),
)
# Only the training batches are shuffled.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = MyDataset(
    txt=os.path.join(root, 'test.txt'),
    transform=torchvision.transforms.ToTensor(),
)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE)

class CNN(nn.Module):
    """Two convolution stages followed by a linear layer with 10 outputs.

    The shape comments assume a single-channel 28x28 input; the final linear
    layer is sized for the resulting 32 * 7 * 7 flattened features.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # number of input channels
                out_channels=16,            # number of filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # keeps height/width: (kernel_size - 1) / 2 when stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # max over 2x2 areas, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, 10 outputs

    # Must be a method of the class: nn.Module dispatches calls to self.forward.
    def forward(self, x):
        """Run the network on a batch.

        Returns:
            A tuple ``(output, x)``: the result of the final linear layer and
            the flattened conv2 features (returned for visualization).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten conv2 output to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x
cnn = CNN()
print(cnn)  # net architecture

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()                      # targets are class indices, not one-hot

# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)[0]            # first element of the model's return is the class scores
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        # Every 50 training steps, evaluate on the whole test set.
        if step % 50 == 0:
            cnn.eval()
            eval_loss = 0.
            eval_acc = 0.
            # No gradients are needed while evaluating.
            with torch.no_grad():
                for t_x, t_y in test_loader:
                    test_output = cnn(t_x)[0]
                    test_loss = loss_func(test_output, t_y)
                    # The loss is a per-batch mean; weight it by the batch size
                    # so that dividing by len(test_data) gives a per-sample mean.
                    eval_loss += test_loss.item() * t_x.size(0)
                    pred = torch.max(test_output, 1)[1]
                    eval_acc += float((pred == t_y).sum().item())
            acc_rate = eval_acc / float(len(test_data))
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))
            # Switch back to training mode before the next optimisation step.
            cnn.train()

训练所用的图片和标签(label)文件的生成方法,见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》。

结果如下:

pytorch cnn 识别手写的字实现自建图片数据

来源:http://www.waitingfy.com/archives/3549

标签:pytorch,cnn,识别手写
0
投稿

猜你喜欢

  • MySQL8新特性之全局参数持久化详解

    2024-01-23 12:17:10
  • MySql版本问题sql_mode=only_full_group_by的完美解决方案

    2024-01-18 16:08:14
  • Python使用Dash开发网页应用的方法详解

    2022-04-26 03:45:14
  • WEB打印大全

    2023-06-30 14:35:15
  • Python的Flask框架及Nginx实现静态文件访问限制功能

    2023-08-13 03:13:38
  • python数据可视化JupyterLab实用扩展程序Mito

    2021-01-24 13:42:04
  • YUI Compressor快速使用指南

    2011-06-27 20:07:30
  • Vue.js中轻松解决v-for执行出错的三个方案

    2024-05-10 14:19:08
  • 分享十个Python超级好用提高工作效率的自动化脚本

    2021-06-26 17:17:16
  • MySQL查询出现1055错误的原因及解决方法

    2024-01-13 04:05:54
  • 使用Python的Django框架结合jQuery实现AJAX购物车页面

    2023-05-21 01:59:28
  • JavaScript Alert通用美化类

    2024-04-27 15:20:50
  • Python如何使用OS模块调用cmd

    2023-03-22 02:25:39
  • MySQL函数与存储过程字符串长度限制的解决

    2024-01-16 13:17:01
  • python中字符串的操作方法大全

    2023-10-01 17:47:15
  • 详解python的字典及相关操作

    2023-03-28 08:52:42
  • Django模型修改及数据迁移实现解析

    2022-05-20 10:20:40
  • Python调用graphviz绘制结构化图形网络示例

    2021-09-15 19:06:49
  • ASPImage组件的实现过程[图]

    2008-02-03 15:37:00
  • Python中的闭包总结

    2023-09-09 03:46:05
  • asp之家 网络编程 m.aspxhome.com