浅谈keras2 predict和fit_generator的坑

作者:BYR_jiandong 时间:2021-05-13 16:30:36 

1、使用predict时,必须设置batch_size,否则效率奇低。

查看keras文档中,predict函数原型:

predict(self, x, batch_size=32, verbose=0)

说明:

只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。

所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。

经验:

使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。

2、fit_generator

说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。

如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于

batch_size = 1,会性能非常低。

经验:

必须明确fit_generator参数steps_per_epoch

补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题

为什么要使用model.fit_generator?

在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。

fit_generator的定义如下:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

其中各项的具体解释,请参考Keras中文文档

我们重点关注的是generator参数:

generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:

一个 (inputs, targets) 元组

一个 (inputs, targets, sample_weights) 元组。

那么,问题来了,如何构建这个generator呢?有以下几种办法:

自己创建一个generator生成器

自己定义一个 Sequence (keras.utils.Sequence) 对象

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator

1.自己创建一个generator生成器

使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。

此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。

此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。

具体实现如下:


 def myGenerator(batch_size):
   # loading data
   X_train,Y_train=load_data(...)

# data processing
   # ................

total_size=X_train.size
   #batch_size means how many data you want to train one step

while 1:
     for i in range(total_size//batch_size):
       yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
 return myGenerator

接着你可以调用该生成器:

self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)

来源:https://blog.csdn.net/lujiandong1/article/details/73556163

标签:keras2,predict,fit,generator
0
投稿

猜你喜欢

  • pycharm中leetcode插件使用图文详解

    2022-09-19 19:19:43
  • 详解Python进阶之切片的误区与高级用法

    2022-09-18 04:03:12
  • Python中实现结构相似的函数调用方法

    2021-12-04 10:31:03
  • 使用pyecharts无法import Bar的解决方案

    2021-04-02 21:31:15
  • Golang利用自定义模板发送邮件的方法详解

    2023-06-29 07:07:16
  • 跟老齐学Python之再深点,更懂list

    2021-02-05 21:44:18
  • MySQL高级查询方法之记录查询

    2010-06-20 14:48:00
  • python实现自动重启本程序的方法

    2022-07-18 14:16:19
  • OpenCV模板匹配matchTemplate的实现

    2021-08-09 15:51:51
  • Python3 Tensorlfow:增加或者减小矩阵维度的实现

    2023-08-25 21:55:40
  • Python使用lambda表达式对字典排序操作示例

    2022-12-26 06:27:46
  • Python Requests模拟登录实现图书馆座位自动预约

    2022-01-31 00:25:46
  • Python实现上下文管理器的方法

    2021-06-22 17:31:15
  • 浅谈python中拼接路径os.path.join斜杠的问题

    2023-08-21 23:41:23
  • python学习之第三方包安装方法(两种方法)

    2021-02-20 03:29:40
  • Windows 2003服务器上传文件受限制的解决方法

    2011-02-14 11:29:00
  • 用PHP+java实现自动新闻滚动窗口

    2023-11-22 12:31:01
  • 恢复被删除的数据 Log Explorer for SQL Server 4.2 (一)

    2010-07-01 19:24:00
  • python sleep和wait对比总结

    2023-04-30 18:26:04
  • DWCS3-CSS布局之一CSS规则大纲

    2008-06-11 18:48:00
  • asp之家 网络编程 m.aspxhome.com