pytorch制作自己的LMDB数据操作示例

作者:团长sama 时间:2023-05-24 11:51:27 

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

代码是在ASTER数据制作部分的代码基础上改了一点:aster_train.txt里每行是一张图片的完整路径;图片同目录下有同名的txt文件,里面记录着该jpg的标签。


import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
    """Return True if ``imageBin`` decodes to an image with a non-zero area.

    The bytes are decoded with OpenCV as grayscale. ``None``, empty input and
    undecodable data are all reported as invalid (False) instead of raising.

    Args:
        imageBin: raw encoded image bytes (e.g. the contents of a .jpg file),
            or None.

    Returns:
        bool: True when the data decodes to an image whose height * width > 0.
    """
    if imageBin is None or len(imageBin) == 0:
        return False
    # np.frombuffer views the bytes as uint8 without the deprecated
    # binary mode of np.fromstring.
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode signals undecodable data by returning None, not raising.
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    return imgH * imgW > 0
def writeCache(env, cache):
    """Store every entry of ``cache`` into ``env`` inside one write transaction.

    Args:
        env: an open LMDB environment (anything whose ``begin(write=True)``
            yields a transaction with a ``put(key, value)`` method).
        cache: dict mapping str keys to values; each key is converted with
            ``str.encode()`` and each value is stored as given.
    """
    with env.begin(write=True) as txn:
        put = txn.put
        for key in cache:
            put(key.encode(), cache[key])
def _is_difficult(word):
    """Return True unless ``word`` consists solely of regex word characters.

    A word is considered "easy" only when it is non-empty and every character
    is a word character (letters, digits, underscore). An empty string, or
    any word containing punctuation, spaces or a trailing newline, is
    "difficult".
    """
    assert isinstance(word, str)
    # Raw string avoids the invalid "\w" escape warning; fullmatch (unlike
    # re.match with "$") does not let a trailing newline slip through.
    return re.fullmatch(r'\w+', word) is None
def createDataset(outputPath, imagePathList, labelList, lexiconList=None,
                  checkValid=True, map_size=1099511627776):
    """
    Create an LMDB dataset for CRNN/ASTER-style text-recognition training.

    Every kept sample gets a contiguous 1-based index ``n`` and is stored as:
      'image-%09d' % n   -> raw bytes of the image file
      'label-%09d' % n   -> label text encoded with str.encode()
      'lexicon-%09d' % n -> space-joined lexicon, encoded (only if lexiconList)
    plus 'num-samples' -> number of samples written, as encoded decimal text.
    Samples with an empty label, a missing image file, or (when checkValid)
    an image rejected by checkImageIsValid are skipped with a printed message.

    ARGS:
      outputPath    : LMDB output path
      imagePathList : list of image paths
      labelList     : list of corresponding ground-truth texts
      lexiconList   : (optional) list of lexicon lists, aligned with imagePathList
      checkValid    : if true, check the validity of every image
      map_size      : maximum LMDB map size in bytes (default 2**40 = 1 TiB)
    """
    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=map_size)
    try:
        cache = {}
        cnt = 1
        for i, (imagePath, label) in enumerate(zip(imagePathList, labelList)):
            if len(label) == 0:
                continue
            if not os.path.exists(imagePath):
                print('%s does not exist' % imagePath)
                continue
            with open(imagePath, 'rb') as f:
                imageBin = f.read()
            if checkValid and not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
            # Everything stored in the database is bytes; keys are the sample
            # counter zero-padded to 9 digits.
            imageKey = 'image-%09d' % cnt
            labelKey = 'label-%09d' % cnt
            cache[imageKey] = imageBin
            cache[labelKey] = label.encode()
            if lexiconList:
                lexiconKey = 'lexicon-%09d' % cnt
                # Must be bytes as well: lmdb's txn.put rejects str values.
                cache[lexiconKey] = ' '.join(lexiconList[i]).encode()
            if cnt % 1000 == 0:
                # Flush in batches to bound memory use of the cache dict.
                writeCache(env, cache)
                cache = {}
                print('Written %d / %d' % (cnt, nSamples))
            cnt += 1
        nSamples = cnt - 1
        cache['num-samples'] = str(nSamples).encode()
        writeCache(env, cache)
    finally:
        # Release the environment even if reading or writing fails midway.
        env.close()
    print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path: str):
    """Read image paths and their labels from a list file.

    ``txt_path`` contains one image path per line. The label of each image is
    the first line of a sibling text file with the same stem and a ``.txt``
    extension (``a/b/img.jpg`` -> ``a/b/img.txt``). Blank lines and images
    whose label file does not exist are skipped.

    Returns:
        (jpg_list, txt_content_list): two aligned lists of str; image paths
        and labels are stripped of surrounding whitespace.

    Raises:
        UnicodeDecodeError: if a label file is not valid UTF-8; its path is
            printed first to identify the offending file.
    """
    def _label_path(image_path):
        # Replace only the extension, so a directory name containing ".jpg"
        # is left untouched.
        return os.path.splitext(image_path)[0] + '.txt'

    with open(txt_path, 'r', encoding='utf-8') as fr:
        stripped = (line.strip() for line in fr)
        jpg_list = [p for p in stripped if p and os.path.exists(_label_path(p))]

    txt_content_list = []
    for jpg in jpg_list:
        label_path = _label_path(jpg)
        with open(label_path, 'r', encoding='utf-8') as fr:
            try:
                first_line = fr.readline()
            except UnicodeDecodeError:
                print(label_path)
                raise
        txt_content_list.append(first_line.strip())
    return jpg_list, txt_content_list
if __name__ == "__main__":
    # Usage: python <this script> [list_txt_path] [lmdb_output_path]
    # Both arguments are optional; without them the original hard-coded
    # paths are used.
    import sys
    txt_path = '/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
    lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
    if len(sys.argv) > 1:
        txt_path = sys.argv[1]
    if len(sys.argv) > 2:
        lmdb_output_path = sys.argv[2]
    imagePathList, labelList = get_sample_list(txt_path)
    createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的是pytorch的Dataset/DataLoader,简单记录一下。人比较懒,代码直接抄过来了,没有整理拆分,重点看__getitem__。


from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
# Build the global options from the command-line arguments (get_args is
# project-defined). LmdbDataset.__init__ reads global_args.run_on_remote.
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
    # Only imported when running remotely: LmdbDataset.__init__ then uses
    # mox.file.exists / mox.file.copy_parallel to copy the dataset locally.
    import moxing as mox
# moxing is a distributed framework; its details are skipped here.
class LmdbDataset(data.Dataset):
    """Text-recognition dataset read from an LMDB database.

    Reads the layout written by the companion creation script: a
    ``num-samples`` entry plus 1-based ``image-%09d`` / ``label-%09d`` keys.
    Each item is ``(img, label, label_len)``:
      img       -- RGB PIL image, or ``transform(img)`` when a transform is set
      label     -- int64 array of length ``max_len``: the character ids, then
                   the EOS id, then PADDING ids
      label_len -- number of characters plus one (for the EOS token)
    """

    def __init__(self, root, voc_type, max_len, num_samples, transform=None):
        """
        Args:
          root        : LMDB directory (passed to mox.file and copied into
                        /cache/<basename> when global_args.run_on_remote)
          voc_type    : 'LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS' or 'DIGITS'
          max_len     : fixed label length, including the EOS token
          num_samples : upper bound on the number of samples exposed
          transform   : optional callable applied to the PIL image
        """
        super(LmdbDataset, self).__init__()
        if global_args.run_on_remote:
            # Copy the dataset into a local cache directory and open it there.
            dataset_name = os.path.basename(root)
            data_cache_url = "/cache/%s" % dataset_name
            if not os.path.exists(data_cache_url):
                os.makedirs(data_cache_url)
            if mox.file.exists(root):
                mox.file.copy_parallel(root, data_cache_url)
            else:
                raise ValueError("%s not exists!" % root)
            self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
        else:
            self.env = lmdb.open(root, max_readers=32, readonly=True)
        assert self.env is not None, "cannot create lmdb from %s" % root
        # NOTE(review): this read transaction is created in the constructing
        # process; LMDB documents that an environment should not be used
        # across fork(), so confirm behavior with DataLoader num_workers > 0.
        self.txn = self.env.begin()
        self.voc_type = voc_type
        self.transform = transform
        self.max_len = max_len
        self.nSamples = int(self.txn.get(b"num-samples"))
        self.nSamples = min(self.nSamples, num_samples)
        assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS', 'DIGITS']
        self.EOS = 'EOS'
        self.PADDING = 'PADDING'
        self.UNKNOWN = 'UNKNOWN'
        self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
        self.char2id = dict(zip(self.voc, range(len(self.voc))))
        self.id2char = dict(zip(range(len(self.voc)), self.voc))
        self.rec_num_classes = len(self.voc)
        self.lowercase = (voc_type == 'LOWERCASE')

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, label, label_len)`` for the 0-based ``index``.

        A corrupted image is reported and replaced by the next sample
        (wrapping around at the end of the dataset).

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``; this also
                lets plain ``for sample in dataset`` iteration terminate.
            KeyError: if the image entry is missing from the database.
        """
        if not 0 <= index < len(self):
            raise IndexError('index range error')
        # LMDB keys are 1-based.
        key_index = index + 1
        img_key = b'image-%09d' % key_index
        imgbuf = self.txn.get(img_key)
        if imgbuf is None:
            raise KeyError('%r not found in LMDB' % img_key)
        # Image.open needs a file-like object, so wrap the raw bytes.
        buf = six.BytesIO(imgbuf)
        try:
            img = Image.open(buf).convert('RGB')
        except IOError:
            print('Corrupted image for %d' % key_index)
            # Fall back to the next sample, wrapping around at the end.
            return self[(index + 1) % len(self)]

        # Recognition label.
        word = self.txn.get(b'label-%09d' % key_index).decode()
        if self.lowercase:
            word = word.lower()
        unknown_id = self.char2id[self.UNKNOWN]
        label_list = []
        for char in word:
            if char not in self.char2id:
                print('{0} is out of vocabulary.'.format(char))
            label_list.append(self.char2id.get(char, unknown_id))
        # Terminate the sequence with the stop token.
        label_list.append(self.char2id[self.EOS])
        assert len(label_list) <= self.max_len, \
            'label length %d (with EOS) exceeds max_len %d' % (len(label_list), self.max_len)
        # Start from all-PADDING, then overwrite the prefix with the real ids.
        label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int64)
        label[:len(label_list)] = label_list
        label_len = len(label_list)

        if self.transform is not None:
            img = self.transform(img)
        return img, label, label_len

希望本文所述对大家Python程序设计有所帮助。

来源:https://blog.csdn.net/sinat_24899403/article/details/102795355

标签:pytorch,LMDB,数据操作
0
投稿

猜你喜欢

  • 详解MySQL中存储函数创建与触发器设置

    2024-01-17 22:58:31
  • python爬虫教程之bs4解析和xpath解析详解

    2023-09-22 19:43:06
  • 三种禁用FileSystemObject组件的方法

    2007-09-23 15:52:00
  • MySQL表自增id溢出的故障复盘解决

    2024-01-24 05:00:50
  • Appium自动化测试中获取Toast信息操作

    2022-05-12 07:10:48
  • pyenv虚拟环境管理python多版本和软件库的方法

    2022-07-18 07:56:30
  • 使用pandas 将DataFrame转化成dict

    2022-08-11 17:46:33
  • 使用Matlab将矩阵保存到csv和txt文件

    2022-11-25 16:08:35
  • python写xml文件的操作实例

    2023-08-09 00:40:39
  • 教你利用Selenium+python自动化来解决pip使用异常

    2022-11-17 18:49:08
  • python中hashlib模块用法示例

    2023-03-20 12:20:13
  • Mybatis报错: org.apache.ibatis.exceptions.PersistenceException解决办法

    2024-01-18 18:58:03
  • AspJpeg 2.0组件使用教程(GIF篇)

    2008-12-16 19:37:00
  • Python实现基于KNN算法的笔迹识别功能详解

    2021-06-18 13:15:08
  • MySQL查看和修改时区的方法

    2024-01-15 05:42:33
  • ASP动态包含文件的改进方法

    2009-01-05 12:22:00
  • MySQL的数据类型和建库策略分析详解

    2024-01-14 11:33:30
  • sql 查询慢的原因分析

    2024-01-16 13:11:29
  • Python正则表达式中flags参数的实例详解

    2021-09-23 10:43:41
  • 使用动画实现微信读书的换一批效果(两种方式)

    2023-10-23 14:30:55
  • asp之家 网络编程 m.aspxhome.com