pytorch版本PSEnet训练并部署方式

作者:__JDM__ 时间:2021-01-06 09:41:18 

概述

源码地址

torch版本

训练环境没有按照torch的readme一样的环境,自己部署环境为:

torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
./compile.sh

制作数据集

1、训练的数据集

采用的是rolabelimg进行标注,需要转换为ic2015格式的数据。

转换代码:

import os
from lxml import etree
import numpy as np
import math
src_xml = "ANN"
txt_dir = "gt"
xml_listdir = os.listdir(src_xml)
xml_listpath = [os.path.join(src_xml,xml_listdir1) for xml_listdir1 in xml_listdir]
def xml_out(xml_path):
   gt_lines = []
   ET = etree.parse(xml_path)
   objs = ET.findall("object")
   for ix,obj in enumerate(objs):
       name = obj.find("name").text
       robox = obj.find("robndbox")
       cx = int(float(robox.find("cx").text))
       cy = int(float(robox.find("cy").text))
       w = int(float(robox.find("w").text))
       h = int(float(robox.find("h").text))
       angle = float(robox.find("angle").text)
       # angle = math.degrees(angle1)
       wx1 = cx - int(0.5 * w)
       wy1 = cy - int(0.5 * h)
       wx2 = cx + int(0.5 * w)
       wy2 = cy - int(0.5 * h)
       wx3 = cx - int(0.5 * w)
       wy3 = cy + int(0.5 * h)
       wx4 = cx + int(0.5 * w)
       wy4 = cy + int(0.5 * h)
       x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
       y1 = int((wx1 - cx) * np.sin(angle) - (wy1 - cy) * np.cos(angle) + cy)
       x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
       y2 = int((wx2 - cx) * np.sin(angle) - (wy2 - cy) * np.cos(angle) + cy)
       x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
       y3 = int((wx3 - cx) * np.sin(angle) - (wy3 - cy) * np.cos(angle) + cy)
       x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
       y4 = int((wx4 - cx) * np.sin(angle) - (wy4 - cy) * np.cos(angle) + cy)
       lines = str(x1)+","+str(y1)+","+str(x2)+","+str(y2)+","+\
               str(x3)+","+str(y3)+","+str(x4)+","+str(y4)+","+str(name)+"\n"
       gt_lines.append(lines)
       return gt_lines
def main():
   count = 0
   for xml_dir in xml_listdir:
       gt_lines = xml_out(os.path.join(src_xml,xml_dir))
       txt_path = "gt_" + xml_dir[:-4] + ".txt"
       with open(os.path.join(txt_dir,txt_path),"a+") as fd:
           fd.writelines(gt_lines)
       count +=1
       print("Write file %s" % str(count))
if __name__ == "__main__":
   main()

rolabelimg标注后的xml文件和labelimg的xml有些区别,根据不同的标注软件,转换代码略有区别。

转换后的格式为x1,y1,x2,y2,x3,y3,x4,y4,"classes",此处classes为检测的类别,如果是模糊训练的话,classes为“###”。

但是重点,这个源代码对于模糊训练,loss一直为1。

2、将数据集分成训练集和测试集

pytorch版本PSEnet训练并部署方式

这里可以按照源码路径存放数据集,也可以修改源码存放位置。

PSENet-python3\dataset\psenet\psenet_ic15.py

修改下述代码为自己文件夹

pytorch版本PSEnet训练并部署方式

3、训练

CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py

其中根据源码中的readme,

pytorch版本PSEnet训练并部署方式

可以根据自己的需要,自行选择配置文件。

pytorch版本PSEnet训练并部署方式

4、部署测试

import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size):
   """Do image preprocessing before prediction on any data.
   :param image:       original image
   :param target_size: target image size
   :return:
                       preprocessed image
   """
   #assert os.path.exists(img), 'file is not exists'
   #img = cv2.imread(img)
   img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
   # h, w = image.shape[:2]
   # scale = long_size / max(h, w)
   img = cv2.resize(img, target_size)
   # 将图片由(w,h)变为(1,img_channel,h,w)
   tensor = transforms.ToTensor()(img)
   tensor = tensor.unsqueeze_(0)
   tensor = tensor.to(torch.device("cuda:0"))
   return tensor
def report_speed(outputs, speed_meters):
   total_time = 0
   for key in outputs:
       if 'time' in key:
           total_time += outputs[key]
           speed_meters[key].update(outputs[key])
           print('%s: %.4f' % (key, speed_meters[key].avg))
   speed_meters['total_time'].update(total_time)
   print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
def load_model(cfg):
   model = build_model(cfg.model)
   model = model.cuda()
   model.eval()
   checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
   if checkpoint is not None:
       if os.path.isfile(checkpoint):
           print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
           sys.stdout.flush()
           checkpoint = torch.load(checkpoint)
           d = dict()
           for key, value in checkpoint['state_dict'].items():
               tmp = key[7:]
               d[tmp] = value
           model.load_state_dict(d)
       else:
           print("No checkpoint found at")
           raise
       # fuse conv and bn
   model = fuse_module(model)
   return model
if __name__ == '__main__':
   src_dir = "testimg/"
   save_dir = "test_save/"
   if not os.path.exists(save_dir):
       os.makedirs(save_dir)
   cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
   for d in [cfg, cfg.data.test]:
       d.update(dict(
           report_speed=False
       ))
   if cfg.report_speed:
       speed_meters = dict(
           backbone_time=AverageMeter(500),
           neck_time=AverageMeter(500),
           det_head_time=AverageMeter(500),
           det_pse_time=AverageMeter(500),
           rec_time=AverageMeter(500),
           total_time=AverageMeter(500)
       )
   model = load_model(cfg)
   model.eval()
   count = 0
   for img_name in os.listdir(src_dir):
       img = cv2.imread(src_dir + img_name)
       tensor = prepare_image(img, target_size=(1376, 1024))
       data = dict()
       img_metas = dict()
       data['imgs'] = tensor
       img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
       img_metas['img_size'] = torch.tensor([[1376, 1024]])
       data['img_metas'] = img_metas
       data.update(dict(
           cfg=cfg
       ))
       with torch.no_grad():
           outputs = model(**data)
       if cfg.report_speed:
           report_speed(outputs, speed_meters)
       for bboxes in outputs['bboxes']:
           x1 = bboxes[0]
           y1 = bboxes[1]
           x2 = bboxes[4]
           y2 = bboxes[5]
           cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
       count = count + 1
       cv2.imwrite(save_dir + img_name, img)
       print("img test:", count)
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter

训练代码里含有。

来源:https://blog.csdn.net/WSNjiang/article/details/120821227

标签:pytorch,PSEnet,训练,部署
0
投稿

猜你喜欢

  • python下如何查询CS反恐精英的服务器信息

    2021-09-15 05:51:47
  • ASP JSON类文件的使用方法

    2011-04-30 16:39:00
  • 利用Python判断文件的几种方法及其优劣对比

    2022-07-03 06:20:02
  • Sql Server 2000内存调整

    2010-04-25 11:24:00
  • Python安装使用命令行交互模块pexpect的基础教程

    2023-07-09 22:43:40
  • asp.net 将一个图片以二进制值的形式存入Xml文件中的实例代码

    2023-07-23 13:31:30
  • python入门学习之自带help功能初步使用示例

    2021-05-27 17:07:28
  • Golang库插件注册加载机制的问题

    2023-06-24 04:25:59
  • Flash如何连接Mysql

    2010-11-11 11:57:00
  • 两个百度WEB面试题 怎么做?

    2010-09-03 18:40:00
  • 改进评论提交表单

    2009-03-25 20:37:00
  • Python代码部署的三种加密方案

    2022-03-22 02:24:40
  • python的简单web框架flask快速实现详解

    2023-03-10 08:26:36
  • python3 dict ndarray 存成json,并保留原数据精度的实例

    2021-03-04 13:25:31
  • 网页设计图标使用指南[译]

    2009-03-11 21:13:00
  • Pandas Matplotlib保存图形时坐标轴标签太长导致显示不全问题的解决

    2023-07-22 20:03:09
  • ASP四级连动下拉列表程序段

    2009-07-03 15:33:00
  • 利用Python抢回在蚂蚁森林逝去的能量(实现代码)

    2022-07-01 15:15:39
  • 导航与搜索合并的可能性

    2009-09-27 12:06:00
  • python文件和目录操作方法大全(含实例)

    2021-11-11 14:10:29
  • asp之家 网络编程 m.aspxhome.com