人工智能学习pyTorch自建数据集及可视化结果实现过程

作者:Swayzzu 时间:2022-08-04 14:54:33 

一、自定义数据集

现有数据如下:

5个文件夹,每个文件夹是神奇宝贝的一种。

人工智能学习pyTorch自建数据集及可视化结果实现过程

每个图片形状、大小、格式不一。

人工智能学习pyTorch自建数据集及可视化结果实现过程

我们训练CNN的时候需要的是tensor类型的数据,因此需要将所有的图片进行下列转换:

1.对文件夹编号,进行映射,比如妙蛙种子文件夹编号0,皮卡丘编号1等。

2.对文件夹中所有图片,进行编号的对应,这个就是标签。并保存为一个csv文件。

3.图片信息获取:分为train,val,test

4.处理图片,使其成为torch可以处理的类型

1.文件夹映射

前半部分为文件夹的映射。我们希望传入数据的时候直接传入文件夹的名字,而文件夹所在的路径就是py文件所在的路径,因此这样可以直接读取。对于路径的操作使用os.path.join进行。

人工智能学习pyTorch自建数据集及可视化结果实现过程

2.图片对应标签

输入的filename,就是我们将图片和标签信息存储的文件。

使用glob.glob方法,可以轻松调取路径下的所有指定类型的文件。

将名字和标签对应好后,通过csv.writer,可以将信息以csv格式写入新文件。

人工智能学习pyTorch自建数据集及可视化结果实现过程

以上是保存的部分,在这个函数中,我们还要重新读取一下这个文件,因为要在这个类中获得最终的图片,以及标签,并且返回。

人工智能学习pyTorch自建数据集及可视化结果实现过程

3.训练及测试数据分割

这里是第一步的图片的后半部分,导入了图片之后,对其进行分割,这里是按照训练、交叉验证、测试,分别是0.6,0.2,0.2进行分割的。

分割完毕后的self.images, self.labels,就可以拿来进行tensor相关的处理了。

人工智能学习pyTorch自建数据集及可视化结果实现过程

4.数据处理

上面几步是准备工作,接下来定义的__getitem__是为了能够使train_loader = DataLoader()这一语句实现。在这里面直接将数据进行我们希望进行的转换。比如大小、旋转、裁剪等。

最后返回处理好的图片,以及tensor化的标签。

人工智能学习pyTorch自建数据集及可视化结果实现过程

另外,还需要定义一个__len__,使得我们可以获得数据集长度。

人工智能学习pyTorch自建数据集及可视化结果实现过程

二、ResNet处理

我们要用ResNet对图片进行处理,因此其中的参数需要进行一定的修改。

主要的修改部分是ResNet18之中的resblock模块。因为我们希望输入的是3通道,224*224的图片,因此在这里对通道,步长进行一定的修改,并进行测试,成功之后便可以进行训练了。

人工智能学习pyTorch自建数据集及可视化结果实现过程

三、训练及可视化

1.数据集导入

同时把GPU设备相关代码准备好,并且由于需要可视化,因此先实例化visdom,并且在终端上输入python -m visdom.server,打开visdom监视终端。

人工智能学习pyTorch自建数据集及可视化结果实现过程

2.测试函数

先把模式改为eval(),接下来就是通过model,去训练测试集,得到标签,并统计正确率。

人工智能学习pyTorch自建数据集及可视化结果实现过程

3.训练过程及可视化

和之前的一样,还是先实例化一个优化器,选择损失函数模式,实例化ResNet18,然后进行训练。

在这里由于要展示,因此先对损失值,交叉验证分数分别设置一个初始的线,通过append的方法,画出我们的损失曲线,以及交叉验证分数曲线。

人工智能学习pyTorch自建数据集及可视化结果实现过程

人工智能学习pyTorch自建数据集及可视化结果实现过程

通过torch.save方法存储我们的最优解。

最后通过把存储好的最优解调用起来,使用测试集,来测试最终的效果。

人工智能学习pyTorch自建数据集及可视化结果实现过程

最终获得的交叉验证准确率89%,测试集准确率88%,损失值及交叉验证结果的图像如下:

人工智能学习pyTorch自建数据集及可视化结果实现过程

来源:https://blog.csdn.net/Swayzzu/article/details/121164368

标签:pytorch,人工智能,数据集,可视化
0
投稿

猜你喜欢

  • 基于Python实现模拟三体运动的示例代码

    2022-03-29 21:40:37
  • 使用python+Flask实现日志在web网页实时更新显示

    2021-03-15 10:16:30
  • Python中HMAC加密算法的应用

    2021-07-29 15:55:18
  • 详解Python中__str__和__repr__方法的区别

    2023-02-28 17:29:17
  • python使用turtle库绘制树

    2022-04-14 09:09:06
  • python类别数据数字化LabelEncoder VS OneHotEncoder区别

    2023-10-12 07:46:46
  • Python文件及目录操作实例详解

    2023-11-26 12:50:27
  • 基于keras输出中间层结果的2种实现方式

    2023-10-11 16:05:49
  • TensorFlow损失函数专题详解

    2023-08-17 10:12:13
  • Pandas把dataframe或series转换成list的方法

    2022-03-24 23:05:02
  • javascript实现延时显示提示框效果

    2024-04-25 13:10:42
  • 深入了解Go的interface{}底层原理实现

    2024-05-21 10:19:03
  • Nodejs之TCP服务端与客户端聊天程序详解

    2024-05-03 15:55:48
  • python3.5 + PyQt5 +Eric6 实现的一个计算器代码

    2021-02-27 17:00:28
  • 如何使用ASP来读写注册表

    2007-09-20 13:08:00
  • Python 识别12306图片验证码物品的实现示例

    2021-04-03 22:17:24
  • Django debug为True时,css加载失败的解决方案

    2022-05-07 01:17:53
  • 如何使用 Go 和 Excelize 构建电子表格

    2024-02-10 04:59:02
  • Macbook安装Python最新版本、GUI开发环境、图像处理、视频处理环境详解

    2021-11-16 17:07:04
  • Python实现的质因式分解算法示例

    2021-12-16 23:10:05
  • asp之家 网络编程 m.aspxhome.com