人工智能学习pyTorch自建数据集及可视化结果实现过程
作者:Swayzzu 发布时间:2022-08-04 14:54:33
一、自定义数据集
现有数据如下:
5个文件夹,每个文件夹是神奇宝贝的一种。
每个图片形状、大小、格式不一。
我们训练CNN的时候需要的是tensor类型的数据,因此需要将所有的图片进行下列转换:
1.对文件夹编号,进行映射,比如妙蛙种子文件夹编号0,皮卡丘编号1等。
2.对文件夹中所有图片,进行编号的对应,这个就是标签。并保存为一个csv文件。
3.图片信息获取:分为train,val,test
4.处理图片,使其成为torch可以处理的类型
1.文件夹映射
前半部分为文件夹的映射。我们希望传入数据的时候直接传入文件夹的名字,而文件夹所在的路径就是py文件所在的路径,因此这样可以直接读取。对于路径的操作使用os.path.join进行。
2.图片对应标签
输入的filename,就是我们将图片和标签信息存储的文件。
使用glob.glob方法,可以轻松调取路径下的所有指定类型的文件。
将名字和标签对应好后,通过csv.writer,可以将信息以csv格式写入新文件。
以上是保存的部分,在这个函数中,我们还要重新读取一下这个文件,因为要在这个类中获得最终的图片,以及标签,并且返回。
3.训练及测试数据分割
这里是第一步的图片的后半部分,导入了图片之后,对其进行分割,这里是按照训练、交叉验证、测试,分别是0.6,0.2,0.2进行分割的。
分割完毕后的self.images, self.labels,就可以拿来进行tensor相关的处理了。
4.数据处理
上面几步是准备工作,接下来定义的__getitem__是为了能够使train_loader = DataLoader()这一语句实现。在这里面直接将数据进行我们希望进行的转换。比如大小、旋转、裁剪等。
最后返回处理好的图片,以及tensor化的标签。
另外,还需要定义一个__len__,使得我们可以获得数据集长度。
二、ResNet处理
我们要用ResNet对图片进行处理,因此其中的参数需要进行一定的修改。
主要的修改部分是ResNet18之中的resblock模块。因为我们希望输入的是3通道,224*224的图片,因此在这里对通道,步长进行一定的修改,并进行测试,成功之后便可以进行训练了。
三、训练及可视化
1.数据集导入
同时把GPU设备相关代码准备好,并且由于需要可视化,因此先实例化visdom,并且在终端上输入python -m visdom.server,打开visdom监视终端。
2.测试函数
先把模式改为eval(),接下来就是通过model,去训练测试集,得到标签,并统计正确率。
3.训练过程及可视化
和之前的一样,还是先实例化一个优化器,选择损失函数模式,实例化ResNet18,然后进行训练。
在这里由于要展示,因此先对损失值,交叉验证分数分别设置一个初始的线,通过append的方法,画出我们的损失曲线,以及交叉验证分数曲线。
通过torch.save方法存储我们的最优解。
最后通过把存储好的最优解调用起来,使用测试集,来测试最终的效果。
最终获得的交叉验证准确率89%,测试集准确率88%,损失值及交叉验证结果的图像如下:
来源:https://blog.csdn.net/Swayzzu/article/details/121164368


猜你喜欢
- 我们一般都认为TRUNCATE是一种不可回滚的操作,它会删除表中的所有数据以及重置Identity列。如果你在事务中进行TRUNCATE操作
- 本文实例讲述了Python使用matplotlib和pandas实现的画图操作。分享给大家供大家参考,具体如下:画图在工作再所难免,尤其在做
- 搜索引擎是通过分析网页源代码来分析页面文本信息的逻辑性,所以在编写网页代码的时候一定要尽可能使用合适的标签来体现文本表达的层次感,也即是让搜
- 解决项目pycharm能运行,在终端却无法运行的问题报 ModuleNotFoundError: No module named '
- DBI安装:DBI详细信息参考:http://dbi.perl.org/ 1.下载DBI包: wget http://search.cpan
- Python Logging原来真的远比我想象的要复杂很多很多,学习路线堪比git。但是又绕不过去,alternatives又少,所以必须要
- MS SQL Server中文版的预设日期datetime格式是yyyy-mm-dd hh:mm:ss.mmm 长短日期格式 --短日期格式
- 1、下载mysql-python官网地址:http://sourceforge.net/projects/mysql-python/2、安装
- 如果看到特别感兴趣的抖音vlogger的视频,想全部dump下来,如何操作呢?下面介绍介绍如何使用python导出特定用户所有视频信息抓包分
- 刚开始,根据我的想法,这个很简单嘛,上sql语句delete from zqzrdp where tel in (select min(dp
- 实战场景 本篇博客学习字体反爬,涉及的站点是实习 x,目标站点地址直接百度搜索即可。可以看到右侧源码中出现了很多&ldqu
- 引言----在实际的web测试工作中,需要配合键盘按键来操作,webdriver的 keys()类提供键盘上所有按键的操作,还可以模拟组合键
- 实验目的:用socket 模拟一个微型的web服务器,当py脚本run起后,实微型web server架起了,然后用本地浏览器访问
- CKeditor编辑器是FCKeditor的升级版本想对于FCK来说,确实比较好用,加载速度也比较快以下是如果通过JS获取CKeditor编
- 本文分析了让ThinkPHP的模板引擎达到最佳效率的方法。分享给大家供大家参考,具体如下:默认情况下ThinkPHP框架系统默认使用的模板引
- 1.首先需要安装pandas, 安装的时候可能由依赖的包需要安装,根据运行时候的提示,缺少哪个库,就pip 安装哪个库。2.示例代码impo
- <script language=javascript> </script>
- 本文实例讲述了python实现与redis交互操作。分享给大家供大家参考,具体如下:相关内容:redis模块的使用设置值获取值安装模块导入模
- 使用 Beanstalkd 作为消息队列服务,然后结合 Python 的装饰器语法实现一个简单的异步任务处理工具.最终效果定义任务:from
- 计模式的目的是让代码易维护、易扩展,不能为了模式而模式,因此一个简单的工具脚本是不需要用到任何模式的。简单工厂模式又叫静态工厂方法模式,工厂