使用pytorch完成kaggle猫狗图像识别方式
作者:charsenhz 时间:2023-04-10 08:39:01
kaggle是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,在这上面有非常多的好项目、好资源可供机器学习、深度学习爱好者学习之用。
碰巧最近入门了一门非常的深度学习框架:pytorch,所以今天我和大家一起用pytorch实现一个图像识别领域的入门项目:猫狗图像识别。
深度学习的基础就是数据,咱们先从数据谈起。此次使用的猫狗分类图像一共25000张,猫狗分别有12500张,我们先来简单的瞅瞅都是一些什么图片。
我们从下载文件里可以看到有两个文件夹:train和test,分别用于训练和测试。以train为例,打开文件夹可以看到非常多的小猫图片,图片名字从0.jpg一直编码到9999.jpg,一共有10000张图片用于训练。
而test中的小猫只有2500张。仔细看小猫,可以发现它们姿态不一,有的站着,有的眯着眼睛,有的甚至和其他可识别物体比如桶、人混在一起。
同时,小猫们的图片尺寸也不一致,有的是竖放的长方形,有的是横放的长方形,但我们最终需要是合理尺寸的正方形。小狗的图片也类似,在这里就不重复了。
紧接着我们了解一下特别适用于图像识别领域的神经网络:卷积神经网络。学习过神经网络的同学可能或多或少地听说过卷积神经网络。这是一种典型的多层神经网络,擅长处理图像特别是大图像的相关机器学习问题。
卷积神经网络通过一系列的方法,成功地将大数据量的图像识别问题不断降维,最终使其能够被训练。CNN最早由Yann LeCun提出并应用在手写体识别上。
一个典型的CNN网络架构如下:
这是一个典型的CNN架构,由卷基层、池化层、全连接层组合而成。其中卷基层与池化层配合,组成多个卷积组,逐层提取特征,最终完成分类。
听到上述一连串的术语如果你有点蒙了,也别怕,因为这些复杂、抽象的技术都已经在pytorch中一一实现,我们要做的不过是正确的调用相关函数,
我在粘贴代码后都会做更详细、易懂的解释。
import os
import shutil
import torch
import collections
from torchvision import transforms,datasets
from __future__ import print_function, division
import os
import torch
import pylab
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
一个正常的CNN项目所需要的库还是蛮多的。
import math
from PIL import Image
class Resize(object):
"""Resize the input PIL Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, size, interpolation=Image.BILINEAR):
# assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation
def __call__(self, img):
w,h = img.size
min_edge = min(img.size)
rate = min_edge / self.size
new_w = math.ceil(w / rate)
new_h = math.ceil(h / rate)
return img.resize((new_w,new_h))
这个称为Resize的库用于给图像进行缩放操作,本来是不需要亲自定义的,因为transforms.Resize已经实现这个功能了,但是由于目前还未知的原因,我的库里没有提供这个函数,所以我需要亲自实现用来代替transforms.Resize。
如果你的torch里面已经有了这个Resize函数就不用像我这样了。
data_transform = transforms.Compose([
Resize(84),
transforms.CenterCrop(84),
transforms.ToTensor(),
transforms.Normalize(mean = [0.5,0.5,0.5],std = [0.5,0.5,0.5])
])
train_dataset = datasets.ImageFolder(root = 'train/',transform = data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 4,shuffle = True,num_workers = 4)
test_dataset = datasets.ImageFolder(root = 'test/',transform = data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = 4,shuffle = True,num_workers = 4)
transforms是一个提供针对数据(这里指的是图像)进行转化的操作库,Resize就是上上段代码提供的那个类,主要用于把一张图片缩放到某个尺寸,在这里我们把需求暂定为要把图像缩放到84 x 84这个级别,这个就是可供调整的参数,大家为部署好项目以后可以试着修改这个参数,比如改成200 x 200,你就发现你可以去玩一盘游戏了~_~。
CenterCrop用于从中心裁剪图片,目标是一个长宽都为84的正方形,方便后续的计算。
ToTenser()就比较重要了,这个函数的目的就是读取图片像素并且转化为0-1的数字。
Normalize作为垫底的一步也很关键,主要用于把图片数据集的数值转化为标准差和均值都为0.5的数据集,这样数据值就从原来的0到1转变为-1到1。
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16 * 18 * 18,800)
self.fc2 = nn.Linear(800,120)
self.fc3 = nn.Linear(120,2)
def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16 * 18 * 18)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
来源:https://blog.csdn.net/xcszjjh1991/article/details/79256576
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
Go语言指针用法详解
![](https://img.aspxhome.com/file/2023/3/97153_0s.png)
vue2.0 解决抽取公用js的问题
![](https://img.aspxhome.com/file/2023/0/123110_0s.jpg)
Hadoop分布式集群的搭建的方法步骤
asp如何实现页面延迟?
将python代码和注释分离的方法
sql格式化工具集合
python本地文件服务器实例教程
![](https://img.aspxhome.com/file/2023/7/64607_0s.jpg)
python ddt实现数据驱动
一文带你学会Python Flask框架设置响应头
![](https://img.aspxhome.com/file/2023/0/134860_0s.png)
详解mysql数据去重的三种方式
![](https://img.aspxhome.com/file/2023/2/69762_0s.png)
SQL Server中使用sp_password重置SA密码实例
python3 lambda表达式详解
python将txt文件读入为np.array的方法
基于np.arange与np.linspace细微区别(数据溢出问题)
![](https://img.aspxhome.com/file/2023/0/115870_0s.png)
利用Python计算KS的实例详解
Python中new方法的详解
python可视化大屏库big_screen示例详解
![](https://img.aspxhome.com/file/2023/7/87487_0s.jpg)