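Example walkthrough: MNIST classification with PyTorch
Author: Hy云帆  Date: 2022-03-30 09:17:19
The torchvision package contains popular datasets, model architectures, and commonly used image transformation tools.
torchvision.datasets includes the following datasets: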
MNIST
COCO (Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10
torchvision.models
The torchvision.models submodule contains the following model architectures:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
You can construct a model with random weights by calling its constructor.
pytorch torchvision transform
The transforms operate on PIL.Image objects.
from __future__ import print_function
import argparse  # Python command-line argument parser
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1x28x28 -> 10x24x24
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10x12x12 -> 20x8x8
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                  # 320 = 20 * 4 * 4
        self.fc2 = nn.Linear(50, 10)                   # one output per digit class

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # -> 10x12x12
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # -> 20x4x4
        x = x.view(-1, 320)                                          # flatten
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities, paired with nll_loss
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)  # no need to shuffle for evaluation
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
if __name__ == '__main__':
    main()
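To see where the 320 in `nn.Linear(320, 50)` comes from, the sketch below repeats the script's `Net` and traces a dummy batch of four 28x28 images. Each 5x5 convolution without padding shrinks the side by 4, and each 2x2 max-pool halves it. The side length therefore goes 28 → 24 → 12 → 8 → 4, and 20 channels × 4 × 4 = 320.

The sketch also checks two properties the train/test loops rely on:

- the rows of the `log_softmax` output exponentiate to probabilities that sum to 1;
- `nll_loss(..., reduction='sum')` divided by the batch size equals the default mean.

The random input and the labels `[0, 1, 2, 3]` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Same layers and forward pass as the Net class in the script."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


torch.manual_seed(0)
net = Net().eval()                   # eval() turns both dropout layers off
x = torch.randn(4, 1, 28, 28)        # dummy batch shaped like 4 MNIST images
target = torch.tensor([0, 1, 2, 3])  # hypothetical labels

with torch.no_grad():
    # ReLU is skipped in this trace because it does not change shapes.
    h1 = F.max_pool2d(net.conv1(x), 2)   # 5x5 conv: 28 -> 24, 2x2 pool: 24 -> 12
    h2 = F.max_pool2d(net.conv2(h1), 2)  # 5x5 conv: 12 -> 8,  2x2 pool: 8 -> 4
    out = net(x)                         # log-probabilities, one row per image

    # Summed loss divided by the number of samples equals the default (mean) loss.
    loss_sum = F.nll_loss(out, target, reduction="sum")
    loss_mean = F.nll_loss(out, target)

print(tuple(h1.shape), tuple(h2.shape), h2[0].numel())  # (4, 10, 12, 12) (4, 20, 4, 4) 320
print(tuple(out.shape))                                 # (4, 10)
print(torch.allclose(out.exp().sum(dim=1), torch.ones(4)))  # True
print(torch.allclose(loss_sum / 4, loss_mean))              # True
```

This is why `test()` can accumulate a summed loss over all batches and divide once by `len(test_loader.dataset)` to obtain the average loss per sample.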
Source: https://blog.csdn.net/KyrieHe/article/details/80516737
Tags: pytorch, MNIST classification