pytorch版本PSEnet训练并部署方式
作者:__JDM__ 发布时间:2021-01-06 09:41:18
标签:pytorch,PSEnet,训练,部署
概述
源码地址
torch版本
训练环境没有按照torch的readme一样的环境,自己部署环境为:
torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
./compile.sh
制作数据集
1、训练的数据集
采用的是rolabelimg进行标注,需要转换为ic2015格式的数据。
转换代码:
import os
from lxml import etree
import numpy as np
import math
src_xml = "ANN"
txt_dir = "gt"
xml_listdir = os.listdir(src_xml)
xml_listpath = [os.path.join(src_xml,xml_listdir1) for xml_listdir1 in xml_listdir]
def xml_out(xml_path):
gt_lines = []
ET = etree.parse(xml_path)
objs = ET.findall("object")
for ix,obj in enumerate(objs):
name = obj.find("name").text
robox = obj.find("robndbox")
cx = int(float(robox.find("cx").text))
cy = int(float(robox.find("cy").text))
w = int(float(robox.find("w").text))
h = int(float(robox.find("h").text))
angle = float(robox.find("angle").text)
# angle = math.degrees(angle1)
wx1 = cx - int(0.5 * w)
wy1 = cy - int(0.5 * h)
wx2 = cx + int(0.5 * w)
wy2 = cy - int(0.5 * h)
wx3 = cx - int(0.5 * w)
wy3 = cy + int(0.5 * h)
wx4 = cx + int(0.5 * w)
wy4 = cy + int(0.5 * h)
x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
y1 = int((wx1 - cx) * np.sin(angle) - (wy1 - cy) * np.cos(angle) + cy)
x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
y2 = int((wx2 - cx) * np.sin(angle) - (wy2 - cy) * np.cos(angle) + cy)
x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
y3 = int((wx3 - cx) * np.sin(angle) - (wy3 - cy) * np.cos(angle) + cy)
x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
y4 = int((wx4 - cx) * np.sin(angle) - (wy4 - cy) * np.cos(angle) + cy)
lines = str(x1)+","+str(y1)+","+str(x2)+","+str(y2)+","+\
str(x3)+","+str(y3)+","+str(x4)+","+str(y4)+","+str(name)+"\n"
gt_lines.append(lines)
return gt_lines
def main():
count = 0
for xml_dir in xml_listdir:
gt_lines = xml_out(os.path.join(src_xml,xml_dir))
txt_path = "gt_" + xml_dir[:-4] + ".txt"
with open(os.path.join(txt_dir,txt_path),"a+") as fd:
fd.writelines(gt_lines)
count +=1
print("Write file %s" % str(count))
if __name__ == "__main__":
main()
rolabelimg标注后的xml文件和labelimg的xml有些区别,根据不同的标注软件,转换代码略有区别。
转换后的格式为x1,y1,x2,y2,x3,y3,x4,y4,"classes"
,此处classes为检测的类别,如果是模糊训练的话,classes为“###”。
但是重点,这个源代码对于模糊训练,loss一直为1。
2、将数据集分成训练集和测试集
这里可以按照源码路径存放数据集,也可以修改源码存放位置。
PSENet-python3\dataset\psenet\psenet_ic15.py
修改下述代码为自己文件夹
3、训练
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py
其中根据源码中的readme,
可以根据自己的需要,自行选择配置文件。
4、部署测试
import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size):
"""Do image preprocessing before prediction on any data.
:param image: original image
:param target_size: target image size
:return:
preprocessed image
"""
#assert os.path.exists(img), 'file is not exists'
#img = cv2.imread(img)
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# h, w = image.shape[:2]
# scale = long_size / max(h, w)
img = cv2.resize(img, target_size)
# 将图片由(w,h)变为(1,img_channel,h,w)
tensor = transforms.ToTensor()(img)
tensor = tensor.unsqueeze_(0)
tensor = tensor.to(torch.device("cuda:0"))
return tensor
def report_speed(outputs, speed_meters):
total_time = 0
for key in outputs:
if 'time' in key:
total_time += outputs[key]
speed_meters[key].update(outputs[key])
print('%s: %.4f' % (key, speed_meters[key].avg))
speed_meters['total_time'].update(total_time)
print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
def load_model(cfg):
model = build_model(cfg.model)
model = model.cuda()
model.eval()
checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
if checkpoint is not None:
if os.path.isfile(checkpoint):
print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
sys.stdout.flush()
checkpoint = torch.load(checkpoint)
d = dict()
for key, value in checkpoint['state_dict'].items():
tmp = key[7:]
d[tmp] = value
model.load_state_dict(d)
else:
print("No checkpoint found at")
raise
# fuse conv and bn
model = fuse_module(model)
return model
if __name__ == '__main__':
src_dir = "testimg/"
save_dir = "test_save/"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
for d in [cfg, cfg.data.test]:
d.update(dict(
report_speed=False
))
if cfg.report_speed:
speed_meters = dict(
backbone_time=AverageMeter(500),
neck_time=AverageMeter(500),
det_head_time=AverageMeter(500),
det_pse_time=AverageMeter(500),
rec_time=AverageMeter(500),
total_time=AverageMeter(500)
)
model = load_model(cfg)
model.eval()
count = 0
for img_name in os.listdir(src_dir):
img = cv2.imread(src_dir + img_name)
tensor = prepare_image(img, target_size=(1376, 1024))
data = dict()
img_metas = dict()
data['imgs'] = tensor
img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
img_metas['img_size'] = torch.tensor([[1376, 1024]])
data['img_metas'] = img_metas
data.update(dict(
cfg=cfg
))
with torch.no_grad():
outputs = model(**data)
if cfg.report_speed:
report_speed(outputs, speed_meters)
for bboxes in outputs['bboxes']:
x1 = bboxes[0]
y1 = bboxes[1]
x2 = bboxes[4]
y2 = bboxes[5]
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
count = count + 1
cv2.imwrite(save_dir + img_name, img)
print("img test:", count)
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
训练代码里含有。
来源:https://blog.csdn.net/WSNjiang/article/details/120821227


猜你喜欢
- 为解决在Vue组件中全局引入 scss 变量及 mixins ,装载了一个名为 "sass-resources-loader&qu
- 第1步:安装cross-envnpm i --save-dev cross-env 第2步:修改各环境下的参数在config/目录下添加te
- Keras的模型是用hdf5存储的,如果想要查看模型,keras提供了get_weights的函数可以查看:for layer in mod
- 本文实例为大家分享了tensorflow如何批量读取图片的具体代码,供大家参考,具体内容如下代码:import tensorflow as
- 我们在做网站的时候,网站后台系统一般都会用到web编辑器,今天笔者就给大家推荐一款百度UEditor编辑器。关于这款百度UEditor编辑器
- 内容摘要:在像网站首页这样的资源比较集中的页面中,那些栏目最经常被用户点击?居左居右对广告的点击率的影响是什么?“一切用数字说话”:以上问题
- 手机控件查看工具uiautomatorviewer工具简介用来扫描和分析Android应用程序的UI控件的工具.如何使用 1.进入
- 本文所述的Python实现冒泡,插入,选择排序简单实例比较适合Python初学者从基础开始学习数据结构和算法,示例简单易懂,具体代码如下:#
- 都知道最近ChatGPT聊天机器人爆火,我也想方设法注册了账号,据说后面要收费了。ChatGPT是一种基于大语言模型的生成式AI,换句话说它
- 如题在SQL中 SELECT ... FROM ... ORDER BY abc ASC; 如果abc是字符串,那么结果会按照a-z 中文按
- 发一个数字拼图游戏,有点小疑问前几天写得,其中一段代码还要感谢“簡簡單單愛妳”的提示,不过我还是不太明白, ,有点笨。 $(&qu
- 说明本文根据https://github.com/liuchengxu/blockchain-tutorial的内容,用python实现的,
- 1、什么是路由懒加载官方的解释:当打包构建应用时,JavaScript 包会变得非常大,影响页面加载。如果我们能把不同路由对应的组件分割成不
- 本文实现了PyQt5个各种弹出窗口:输入框、消息框、文件对话框、颜色对话框、字体对话框、自定义对话框其中,为了实现自定义对话框的返回值,使用
- 直接调用系统的颜色显示在网页上本来是件很好玩滴事,但是,也有个缺点,就是可用的色太少 比如Bindows在它的启动画面一点点应用。=。= 上
- 1:strip()方法去除字符串开头或者结尾的空格>>> a = " a b c ">>&
- 在用tensorflow做一维的卷积神经网络的时候会遇到tf.nn.conv1d和layers.conv1d这两个函数,但是这两个函数有什么
- 总有一些程序在windows平台表现不稳定,动不动一段时间就无响应,但又不得不用,每次都是发现问题了手动重启,现在写个脚本定时检测进程是否正
- 下面是最终代码 (windows下实现的) # -*- coding: cp936 -*- import os path = 'D:
- 一、什么是Python类?python中的类是创建特定对象的蓝图。它使您可以以特定方式构建软件。问题来了,怎么办?类允许我们以一种易于重用的