keras使用Sequence类调用大规模数据集进行训练的实现

作者:aszxs 时间:2021-01-03 20:24:35 

使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。

下面是我所使用的代码


class SequenceData(Sequence):
 def __init__(self, path, batch_size=32):
   self.path = path
   self.batch_size = batch_size
   f = open(path)
   self.datas = f.readlines()
   self.L = len(self.datas)
   self.index = random.sample(range(self.L), self.L)
 #返回长度,通过len(<你的实例>)调用
 def __len__(self):
   return self.L - self.batch_size
 #即通过索引获取a[0],a[1]这种
 def __getitem__(self, idx):
   batch_indexs = self.index[idx:(idx+self.batch_size)]
   batch_datas = [self.datas[k] for k in batch_indexs]
   img1s,img2s,audios,labels = self.data_generation(batch_datas)
   return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})

def data_generation(self, batch_datas):
   #预处理操作
   return img1s,img2s,audios,labels

然后在代码里通过fit_generation函数调用并训练

这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。


D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)),
         epochs=2, workers=20, #callbacks=[checkpoint],
         use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))

同样的,也可以在测试的时候使用

model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)

补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练

我就废话不多说了,大家还是直接看代码吧~


#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense

class DataGenerator(keras.utils.Sequence):

def __init__(self, datas, batch_size=1, shuffle=True):
   self.batch_size = batch_size
   self.datas = datas
   self.indexes = np.arange(len(self.datas))
   self.shuffle = shuffle

def __len__(self):
   #计算每一个epoch的迭代次数
   return math.ceil(len(self.datas) / float(self.batch_size))

def __getitem__(self, index):
   #生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
   # 生成batch_size个索引
   batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
   # 根据索引获取datas集合中的数据
   batch_datas = [self.datas[k] for k in batch_indexs]

# 生成数据
   X, y = self.data_generation(batch_datas)

return X, y

def on_epoch_end(self):
   #在每一次epoch结束是否需要进行一次随机,重新随机一下index
   if self.shuffle == True:
     np.random.shuffle(self.indexes)

def data_generation(self, batch_datas):
   images = []
   labels = []

# 生成数据
   for i, data in enumerate(batch_datas):
     #x_train数据
     image = cv2.imread(data)
     image = list(image)
     images.append(image)
     #y_train数据
     right = data.rfind("\\",0)
     left = data.rfind("\\",0,right)+1
     class_name = data[left:right]
     if class_name=="dog":
       labels.append([0,1])
     else:
       labels.append([1,0])
   #如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
   return np.array(images), np.array(labels)

# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
 file_path = os.path.join("D:/xxx", file)
 if os.path.isdir(file_path):
   class_num = class_num + 1
   for sub_file in os.listdir(file_path):
     train_datas.append(os.path.join(file_path, sub_file))

# 数据生成器
training_generator = DataGenerator(train_datas)

#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
      optimizer='sgd',
      metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)

来源:https://blog.csdn.net/qq_22033759/article/details/88798423

标签:keras,Sequence,数据集,训练
0
投稿

猜你喜欢

  • python网络爬虫实战

    2021-04-16 13:31:07
  • Python iter()函数用法实例分析

    2022-11-01 00:00:01
  • python3调用R的示例代码

    2021-05-01 21:53:59
  • 使用Python获取CPU、内存和硬盘等windowns系统信息的2个例子

    2023-08-26 23:12:32
  • Python中的jquery PyQuery库使用小结

    2023-05-27 11:08:15
  • 详解PHP中的mb_detect_encoding函数使用方法

    2023-11-14 19:48:45
  • Pytorch 解决自定义子Module .cuda() tensor失败的问题

    2023-11-19 15:01:57
  • Python将list中的string批量转化成int/float的方法

    2021-12-11 00:11:59
  • pycharm2020.1.2永久破解激活教程,实测有效

    2021-11-01 15:17:57
  • python selenium自动化测试框架搭建的方法步骤

    2023-05-24 21:38:49
  • 如何利用Python动态模拟太阳系运转

    2022-01-14 15:01:43
  • Windows 下更改 jupyterlab 默认启动位置的教程详解

    2023-06-11 13:10:12
  • anaconda安装pytorch1.7.1和torchvision0.8.2的方法(亲测可用)

    2021-01-13 03:03:38
  • ASP.NET中MD5和SHA1密码保护算法的使用

    2007-08-24 09:18:00
  • Python全景系列之模块与包全面解读

    2022-12-09 19:26:48
  • Python利用WMI实现ping命令的例子

    2022-07-12 04:42:22
  • ASP真正随机不重复查询代码

    2010-01-02 20:40:00
  • 在laravel中使用Symfony的Crawler组件分析HTML

    2023-11-17 18:54:07
  • Python实现的服务器示例小结【单进程、多进程、多线程、非阻塞式】

    2023-02-24 00:19:25
  • 好的Python培训机构应该具备哪些条件

    2022-06-22 14:52:57
  • asp之家 网络编程 m.aspxhome.com