使用pytorch完成kaggle猫狗图像识别方式

作者:charsenhz 时间:2023-04-10 08:39:01 

kaggle是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,在这上面有非常多的好项目、好资源可供机器学习、深度学习爱好者学习之用。

碰巧最近入门了一门非常的深度学习框架:pytorch,所以今天我和大家一起用pytorch实现一个图像识别领域的入门项目:猫狗图像识别。

深度学习的基础就是数据,咱们先从数据谈起。此次使用的猫狗分类图像一共25000张,猫狗分别有12500张,我们先来简单的瞅瞅都是一些什么图片。

我们从下载文件里可以看到有两个文件夹:train和test,分别用于训练和测试。以train为例,打开文件夹可以看到非常多的小猫图片,图片名字从0.jpg一直编码到9999.jpg,一共有10000张图片用于训练。

而test中的小猫只有2500张。仔细看小猫,可以发现它们姿态不一,有的站着,有的眯着眼睛,有的甚至和其他可识别物体比如桶、人混在一起。

同时,小猫们的图片尺寸也不一致,有的是竖放的长方形,有的是横放的长方形,但我们最终需要是合理尺寸的正方形。小狗的图片也类似,在这里就不重复了。

紧接着我们了解一下特别适用于图像识别领域的神经网络:卷积神经网络。学习过神经网络的同学可能或多或少地听说过卷积神经网络。这是一种典型的多层神经网络,擅长处理图像特别是大图像的相关机器学习问题。

卷积神经网络通过一系列的方法,成功地将大数据量的图像识别问题不断降维,最终使其能够被训练。CNN最早由Yann LeCun提出并应用在手写体识别上。

一个典型的CNN网络架构如下:

使用pytorch完成kaggle猫狗图像识别方式

这是一个典型的CNN架构,由卷基层、池化层、全连接层组合而成。其中卷基层与池化层配合,组成多个卷积组,逐层提取特征,最终完成分类。

听到上述一连串的术语如果你有点蒙了,也别怕,因为这些复杂、抽象的技术都已经在pytorch中一一实现,我们要做的不过是正确的调用相关函数,

我在粘贴代码后都会做更详细、易懂的解释。


import os
import shutil
import torch
import collections
from torchvision import transforms,datasets
from __future__ import print_function, division
import os
import torch
import pylab
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion() # interactive mode

一个正常的CNN项目所需要的库还是蛮多的。


import math
from PIL import Image

class Resize(object):
"""Resize the input PIL Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
 (h, w), output size will be matched to this. If size is an int,
 smaller edge of the image will be matched to this number.
 i.e, if height > width, then image will be rescaled to
 (size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
 ``PIL.Image.BILINEAR``
"""

def __init__(self, size, interpolation=Image.BILINEAR):
# assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation

def __call__(self, img):
w,h = img.size

min_edge = min(img.size)
rate = min_edge / self.size

new_w = math.ceil(w / rate)
new_h = math.ceil(h / rate)

return img.resize((new_w,new_h))

这个称为Resize的库用于给图像进行缩放操作,本来是不需要亲自定义的,因为transforms.Resize已经实现这个功能了,但是由于目前还未知的原因,我的库里没有提供这个函数,所以我需要亲自实现用来代替transforms.Resize。

如果你的torch里面已经有了这个Resize函数就不用像我这样了。


data_transform = transforms.Compose([
Resize(84),
transforms.CenterCrop(84),
transforms.ToTensor(),
transforms.Normalize(mean = [0.5,0.5,0.5],std = [0.5,0.5,0.5])
])

train_dataset = datasets.ImageFolder(root = 'train/',transform = data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 4,shuffle = True,num_workers = 4)

test_dataset = datasets.ImageFolder(root = 'test/',transform = data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = 4,shuffle = True,num_workers = 4)

transforms是一个提供针对数据(这里指的是图像)进行转化的操作库,Resize就是上上段代码提供的那个类,主要用于把一张图片缩放到某个尺寸,在这里我们把需求暂定为要把图像缩放到84 x 84这个级别,这个就是可供调整的参数,大家为部署好项目以后可以试着修改这个参数,比如改成200 x 200,你就发现你可以去玩一盘游戏了~_~。

CenterCrop用于从中心裁剪图片,目标是一个长宽都为84的正方形,方便后续的计算。

ToTenser()就比较重要了,这个函数的目的就是读取图片像素并且转化为0-1的数字。

Normalize作为垫底的一步也很关键,主要用于把图片数据集的数值转化为标准差和均值都为0.5的数据集,这样数据值就从原来的0到1转变为-1到1。


class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()

self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16 * 18 * 18,800)
self.fc2 = nn.Linear(800,120)
self.fc3 = nn.Linear(120,2)

def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16 * 18 * 18)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)

return x

net = Net()

来源:https://blog.csdn.net/xcszjjh1991/article/details/79256576

标签:pytorch,kaggle,图像识别
0
投稿

猜你喜欢

  • Go语言指针用法详解

    2023-08-05 17:06:36
  • vue2.0 解决抽取公用js的问题

    2024-05-28 15:59:28
  • Hadoop分布式集群的搭建的方法步骤

    2022-06-08 06:02:42
  • asp如何实现页面延迟?

    2010-06-03 10:18:00
  • 将python代码和注释分离的方法

    2022-04-06 12:04:50
  • sql格式化工具集合

    2024-01-14 02:15:14
  • python本地文件服务器实例教程

    2022-07-31 16:38:17
  • python ddt实现数据驱动

    2021-11-11 02:37:08
  • 一文带你学会Python Flask框架设置响应头

    2023-04-16 03:11:37
  • 详解mysql数据去重的三种方式

    2024-01-22 03:06:35
  • SQL Server中使用sp_password重置SA密码实例

    2024-01-20 15:07:59
  • python3 lambda表达式详解

    2021-03-01 20:28:20
  • python将txt文件读入为np.array的方法

    2023-07-23 08:10:29
  • 基于np.arange与np.linspace细微区别(数据溢出问题)

    2021-08-29 23:46:25
  • 利用Python计算KS的实例详解

    2021-10-16 12:24:09
  • Python中new方法的详解

    2022-12-06 14:20:18
  • python可视化大屏库big_screen示例详解

    2021-10-16 14:32:39
  • Python中int()函数的用法浅析

    2022-08-18 09:45:12
  • Pytorch中的gather使用方法

    2021-11-22 06:11:49
  • MySQL自定义函数简单用法示例

    2024-01-20 12:47:17
  • asp之家 网络编程 m.aspxhome.com