关于SSD目标检测模型的人脸口罩识别
作者:Mabel-mql 发布时间:2023-06-20 05:20:56
最近学习了SSD算法,了解了其基本的实现思路,并通过SSD模型训练自己的模型。
基本环境
torch1.2.0
Pillow8.2.0
torchvision0.4.0
CUDA版本可查看自己电脑,这里使用CUDA10.0
visual studio 2019
scipy1.2.1
numpy1.17.0
matplotlib3.1.2
opencv_python4.1.2.30
tqdm4.60.0
h5py2.10.0
安装
建议创建一个虚拟环境,本文使用到的是在Pycharm环境下
打开pytorch的官方安装方法:
https://pytorch.org/get-started/previous-versions/
但是可以先进入:
https://download.pytorch.org/whl/torch_stable.html
找到自己需要下载自己需要的即可。
找到自己的下载路径,然后再命令窗口定位,再使用
pip install +下载好的whl文件即可
再安装相关依赖包需要先激活环境,进行安装。
同时安装CUDA和visual studio 2019可参考网上教程,这里不细讲。
数据集的准备
本文使用VOC格式进行训练,
训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的Annotation中,文件格式为xml。
图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中,格式为jpg,如下图所示。
数据集处理
整个项目的文件如下(里面包含一些个人测试的代码):
第一步需要运行voc_annotation.py,并更改其代码里面的一些参数(annotation_mode、classes_path、trainval_percent、train_percent、VOCdevkit_path都可以修改,但也可以只修改以下内容即可):
需要修改model_data文件里面的voc_classes.txt内容,例如本例中修改如下:
即可生成训练用的2007_train.txt以及2007_val.txt。
图片处理
本例统一输入进来的图片是300*300大小的3通道图片。
对输入进来的图片进行判断是否为RGB,如果不是则进行转RGB
对图像进行统一大小裁剪,为防止图片失真,在其添加上灰条。
对图片进行数据增强,通过翻转,随机选取等操作。
模型训练
训练文件train.py中也要修改部分参数
classes_path一定要对应自己的分类文件,以及自己权重文件的位置。经过多次epochs后,权值会生成在logs文件夹。
在训练开始前还需要更改其他py文件的内容:
在summary.py文件中:
m=SSD300(7,‘vgg’).to(device)中7代表的是分类的个数,这里需要修改为2,因为只本例只分为了2类。
下面(3,300,300)代表输入的是300*300大小的3通道图片。
运行train.py文件进行模型训练,若出现out of memory问题,可以减小每次训练的batch_size的大小。
模型预测
模型预测先要去修改ssd.py文件中的model_path(在自己保存权值的logs文件当中选取一个权值文件,放到model_data文件夹中,并修改下面的路径,其次classes_path也要进行对应的修改:
这里单独调用摄像头进行预测,相关代码如下所示:
import time
import cv2
import numpy as np
from PIL import Image
from ssd import SSD
#口罩识别模型
if __name__ == "__main__":
ssd = SSD()
video_path = 0
video_save_path = ""
video_fps = 25.0
# 指定测量fps的时候,图片检测的次数
test_interval = 100
capture=cv2.VideoCapture(video_path)
if video_save_path!="":
fourcc = cv2.VideoWriter_fourcc(*'XVID')
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
ref, frame = capture.read()
if not ref:
raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
fps = 0.0
while(True):
t1 = time.time()
# 读取某一帧
ref, frame = capture.read()
if not ref:
break
# 格式转变,BGRtoRGB
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
# 转变成Image
frame = Image.fromarray(np.uint8(frame))
# 进行检测
frame = np.array(ssd.detect_image(frame))
# RGBtoBGR满足opencv显示格式
frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
fps = ( fps + (1./(time.time()-t1)) ) / 2
print("fps= %.2f"%(fps))
frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
cv2.imshow("video",frame)
if video_save_path!="":
out.write(frame)
if cv2.waitKey(10) & 0xff==ord('q'):
break
capture.release()
cv2.destroyAllWindows()
效果图如下
未戴口罩
戴口罩
整体来说效果还是不错的。
后续
后面我又去找了其他数据集进行训练,对其进行不同的图片处理以及模型的改进,达到的效果还不错。但是图片格式为jpeg的,因此在代码当中添加了对图片类型的判断,但是若不添加代码,则需要更改文件
get_map.py中:
后缀为对应的图片类型,还有在voc_annotation.py代码中有一处也需要修改图片后缀名。
其次自己写了一个简易版的GUI界面,使其输出各坐标,以及害虫的分类
效果图如下:
但在模型对小目标检测方面还是存在一点问题,正在尝试提高其精度。
建议还是要先去学习下SSD模型的基本算法思路,理解起来更加清楚、明白.
来源:https://blog.csdn.net/Rebacca122222/article/details/124578323
猜你喜欢
- 首先,创建一个存储过程 get_clob:t_name:要查询的表名;f_name:要查询的字段名;u_id:表的主键,查询条件;l_pos
- 本文实例讲述了mysql中left join设置条件在on与where时的用法区别。分享给大家供大家参考,具体如下:一、首先我们准备两张表来
- MMClassification是一个基于PyTorch的开源图像分类工具箱,是OpenMMLab项目的一部分,源码传送门,最新发布版本为v
- 前言由于后端使用php、node.js、java等进行大量的图片下载操作可能会导致服务器负载过高,所以将图片下载转移到客户端是个不错的选择,
- python的try语句有两种风格一是处理异常(try/except/else)二是无论是否发生异常都将执行最后的代码(try/finall
- 一、校验数字的表达式数字:^[0-9]*$n位的数字:^\d{n}$至少n位的数字:^\d{n,}$m-n位的数字:^\d{m,n}$零和非
- 我们在建立一个大型网站的时候会有很多副页面框架模式,甚至一些细节元素都是相同的。但令人困扰的是更新它们却要费些周折,要一遍遍地反复更新每个页
- 思路:利用栈实现代数式中括号有效行的的检验:代码:class mychain(object): #利用链表建立栈,链表为父类 length=
- 1,ajax(asynchronouse javascript and xml)异步的 javascrip 和xml 2,(包含了7种技术:
- python数组进行降维在深度学习训练过程中,我们有时候想要输出图片看看图片长什么样,但是训练时的图片格式一般都会多出一个批次的维度,如[1
- 一、函数入门1.概念函数是可以重复执行一定任务的代码片段,具有独立的固定的输入输出接口。函数定义的本质,是给一段代码取个名字,方便以后重复使
- 前言mysql是高版本,当执行group by时,select的字段不属于group by的字段的话,sql语句就会报错。错误提示:this
- python 包含子目录中的模块方法比较简单,关键是能够在sys.path里面找到通向模块文件的路径。下面将具体介绍几种常用情况: (1)主
- 前言在当前的JavaScript中,并没有枚举这个概念,在某些场景中使用枚举更能保证数据的正确性,减少数据校验过程,下面就介绍一下JavaS
- 在CMD控制台进入Jupyter notebook之前,先激活安装了该模块的配置环境,再启动jupyter notebook,问题完美解决。
- //获取元素的样式值。 function getStyle(elem,name){ if(elem.style[name]){ return
- 定义计算N的阶乘的函数1)使用循环计算阶乘def frac(n): r = 1 if n<=1:
- <?php session_start(); $_SESSION['username']="zhuzhao&
- 日志文件对于一个服务器来说是非常重要的,它记录着服务器的运行信息,许多操作都会写日到日志文件,通过日志文件可以监视服务器的运行状态及查看服务
- 写桌面程序或有些特殊操作的,经常需要访问剪切板。python有专用的模块,可以很方便简单的操作剪切板如下:#coding:utf-8impo