PyTorch 如何将CIFAR100数据按类标归类保存
作者:Xie_learning 时间:2023-01-10 06:01:03
few-shot learning的采样
Few-shot learning 以任务(task)为单位对模型进行训练。在 N-way K-shot 设定下,每个任务先从 meta-training 的类别中抽取 N 个类,每一类抽取 K 个样本构成 support set;query set 则是从这 N 个类剩余的样本中再采样一定数量的样本(各类可以均匀采样,也可以不均匀采样)。
对数据按类标归类
针对上述情况,我们需要使用不同类别放置在不同文件夹的数据集。但有时,数据并没有按类放置,这时就需要对数据进行处理。
下面以CIFAR100为例(不含N-way-k-shot的采样):
import os
from skimage import io
import torchvision as tv
import numpy as np
import torch
def _cifar_images_and_labels(dataset, legacy_prefix):
    """Return ``(images, labels)`` from a torchvision CIFAR dataset object.

    Newer torchvision releases expose the raw arrays as ``data`` / ``targets``;
    older releases used ``train_data`` / ``train_labels`` and ``test_data`` /
    ``test_labels``. ``legacy_prefix`` (``'train'`` or ``'test'``) selects the
    old attribute names when the new ones are absent.
    """
    images = getattr(dataset, 'data', None)
    if images is None:
        images = getattr(dataset, legacy_prefix + '_data')
    labels = getattr(dataset, 'targets', None)
    if labels is None:
        labels = getattr(dataset, legacy_prefix + '_labels')
    return images, labels


def _save_split(root, split_name, classes):
    """Write ``classes`` to ``root/split_name/character_<i>/<j>.jpg``."""
    split_dir = os.path.join(root, split_name)
    # exist_ok=True lets the script be re-run; existing images are overwritten.
    os.makedirs(split_dir, exist_ok=True)
    for i, per_class in enumerate(classes):
        character_path = os.path.join(split_dir, 'character_' + str(i))
        os.makedirs(character_path, exist_ok=True)
        for j, img in enumerate(per_class):
            # `io` is skimage.io, imported at the top of the file.
            io.imsave(os.path.join(character_path, str(j) + ".jpg"), img)


def Cifar100(root):
    """Download CIFAR-100 and save its images grouped by class under ``root``.

    The train and test sets are merged, images are grouped by label, and the
    100 classes are shuffled with a fixed seed and split 64 / 16 / 20 into
    ``meta_training``, ``meta_validation`` and ``meta_testing``. Each class is
    written to its own ``character_<i>`` folder as ``<j>.jpg`` files.

    Args:
        root: Directory used both as the torchvision download location and as
            the parent of the three output folders.

    Returns:
        None. The result is the directory tree written to disk.
    """
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)

    # Group every image (train and test) by its class label.
    per_class = [[] for _ in range(100)]
    for dataset, prefix in ((train_set, 'train'), (test_set, 'test')):
        images, labels = _cifar_images_and_labels(dataset, prefix)
        for image, label in zip(images, labels):
            per_class[int(label)].append(image)  # each image is 32*32*3

    # Stacking needs the same number of images per class (CIFAR-100 has 600).
    # The data stays a NumPy array because that is what io.imsave consumes.
    character = np.array(per_class)

    # Shuffle the class order. The fixed seed keeps the split reproducible;
    # note that this reseeds NumPy's global random generator.
    np.random.seed(6)
    shuffle_class = np.arange(len(character))
    np.random.shuffle(shuffle_class)
    character = character[shuffle_class]

    # meta_training : meta_validation : meta_testing = 64 : 16 : 20 classes
    splits = (
        ('meta_training', character[:64]),
        ('meta_validation', character[64:80]),
        ('meta_testing', character[80:]),
    )
    for split_name, classes in splits:
        _save_split(root, split_name, classes)
if __name__ == '__main__':
    # Directory that receives both the CIFAR-100 download and the exported
    # per-class image folders.
    root = '/home/xie/文档/datasets/cifar_100'
    Cifar100(root)
    # Separator line printed once the export has finished.
    print("-----------------")
补充:使用Pytorch对数据集CIFAR-10进行分类
主要是以下几个步骤:
1、下载并预处理数据集
2、定义网络结构
3、定义损失函数和优化器
4、训练网络并更新参数
5、测试网络效果
#数据加载和预处理
#使用CIFAR-10数据进行分类实验
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
# Converts a Tensor back into a PIL Image, which is handy for visualization.
show = ToPILImage()

# Preprocessing applied to every image.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Normalize each of the three channels with mean 0.5 and std 0.5.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training set, stored under ./data/ (download requested if missing).
trainset = tv.datasets.CIFAR10('./data/', train=True, transform=transform, download=True)
trainloader = t.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

# Test set: same preprocessing, iterated without shuffling.
testset = tv.datasets.CIFAR10('./data/', train=False, transform=transform, download=True)
testloader = t.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

# Class names in label-index order.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
初次下载需要一些时间,运行结束后,显示如下:
import torch.nn as nn
import torch.nn.functional as F
import time
start = time.time()# start of the timed section; elapsed time is computed after the training loop
# Define the network structure.
class Net(nn.Module):
    """Small LeNet-style CNN producing 10 class scores per image.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed
    three fully connected layers. ``fc1`` expects 16*5*5 input features,
    which is what the convolution/pooling stack yields for 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (unnormalized) class scores of shape (batch, 10)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike view(),
        # flatten() also accepts non-contiguous activations.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No softmax here: nn.CrossEntropyLoss takes raw scores.
        return self.fc3(x)
net = Net()  # instantiate the network defined above
print(net)  # show the network's layer structure
显示net结构如下:
# Loss function and optimizer.
loss_func = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the network: two passes over the training loader.
for epoch in range(2):
    running_loss = 0
    for i, data in enumerate(trainloader):
        inputs, labels = data

        # Clear gradients left over from the previous step, then run the
        # forward pass, backpropagate and update the parameters.
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 2000 != 0:
            continue
        # Every 2000 mini-batches: report the mean loss and reset the sum.
        print('epoch:', epoch + 1, '|i:', i + 1, '|loss:%.3f' % (running_loss / 2000))
        running_loss = 0.0

# Elapsed time since `start` was recorded.
end = time.time()
time_using = end - start
print('finish training')
print('time:', time_using)
结果如下:
下一步进行使用测试集进行网络测试:
# Evaluate the network on the test set.
correct = 0  # number of correctly predicted images
total = 0  # total number of images seen
with t.no_grad():  # gradients are not needed for evaluation
    for data in testloader:
        images, labels = data
        outputs = net(images)
        # The index of the highest score is the predicted class.
        _, predict = t.max(outputs, 1)
        total += labels.size(0)
        # .item() keeps `correct` a plain Python int instead of turning it
        # into a tensor, so the percentage below is ordinary arithmetic.
        correct += (predict == labels).sum().item()
print('测试集中的准确率为:%d%%' % (100 * correct / total))
结果如下:
可以看到,即使是这样简单的网络,训练后的准确率也确实高于随机猜测的10%(共10个类别):)
在GPU中训练:
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if t.cuda.is_available():
    device = t.device('cuda:0')
else:
    device = t.device('cpu')
net.to(device)
# `images` / `labels` hold whatever batch the preceding code left in them;
# they must be on the same device as the network.
images, labels = images.to(device), labels.to(device)
output = net(images)
loss = loss_func(output, labels)
# Bare expression: displays the loss in an interactive session or notebook.
loss
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/Xie_learning/article/details/89365305
标签:PyTorch,CIFAR100,类标,保存
0
投稿
猜你喜欢
解决SQL Server的“此数据库没有有效所有者”问题
2024-01-16 22:21:58
MySQL复制的概述、安装、故障、技巧、工具(火丁分享)
2024-01-18 02:29:49
mysql连接的空闲时间超过8小时后 MySQL自动断开该连接解决方案
2024-01-15 21:35:26
在Python程序中操作MySQL的基本方法
2024-01-20 18:30:46
js打开新窗口方法整理
2024-04-10 16:13:05
Python基础之循环语句用法示例【for、while循环】
2022-06-03 19:37:50
python抓取京东商城手机列表url实例代码
2022-11-11 18:23:04
ASP编写计数器的优化方法
2009-01-21 19:46:00
Tornado Web Server框架编写简易Python服务器
2021-10-18 09:23:52
使用python生成大量数据写入es数据库并查询操作(2)
2024-01-14 02:10:27
Python配置虚拟环境图文步骤
2023-10-13 01:37:40
django文档学习之applications使用详解
2021-09-11 11:15:12
解决django xadmin主题不显示和只显示bootstrap2的问题
2022-11-24 14:41:56
python超参数优化的具体方法
2022-01-04 22:28:31
利用python实现查看溧阳的摄影圈
2021-09-05 21:33:16
Python实现数字的格式化输出
2021-10-11 18:11:27
在MySQL中用正则表达式替换数据库中的内容的方法
2024-01-17 02:51:57
python3.6.3安装图文教程 TensorFlow安装配置方法
2021-06-25 19:20:42
vscode中使用Autoprefixer3.0无效的解决方法
2023-10-05 11:03:46
Python科学画图代码分享
2023-08-19 07:06:25