Keras构建神经网络踩坑(解决model.predict预测值全为0.0的问题)

作者:qq_42972774 时间:2023-03-28 04:05:22 

终于构建出了第一个神经网络,Keras真的很方便。

之前不知道Keras这么方便,在构建神经网络的过程中绕了很多弯路,最开始学的TensorFlow,后来才知道Keras。

TensorFlow和Keras的关系,就像c语言和python的关系,所以Keras是真的好用。

搞不清楚数据的标准化和归一化的关系,想对原始数据做归一化,却误把数据做了标准化,导致用model.predict预测出来的值全是0.0,在网上搜了好久但是没搜到答案,后来自己又把程序读了一遍,突然灵光一现好像是数据归一化出了问题,于是把数据预处理部分的标准化改成了归一化,修改过来之后才能正常预测出来值,才得到应有的数据趋势。

标准化:

(x-mean(x))/std(x) 这是使用z-score方法规范化

归一化:

(x-min(x))/(max(x)-min(x)) 这是常用的最小最大规范化方法

补充知识:keras加载已经训练好的模型文件,进行预测时却发现预测结果几乎为同一类(本人预测时几乎均为为第0类)**

原因:在进行keras训练时候,使用了keras内置的数据读取方式,但是在进行预测时候,使用了自定义的数据读取方式,本人为图片读取。

解决办法查看如下代码:


##############训练:
train_gen = ImageDataGenerator(rotation_range=10,
   width_shift_range=0.2,
   shear_range=0.2,
   zoom_range=0.2,
   fill_mode='constant',
   cval=0)
train_generator = train_gen.flow_from_directory(train_path,
     target_size=(224, 224),
     batch_size=16,
     class_mode='categorical',
     save_to_dir=train_g,
     save_prefix='man',
     save_format='jpg')

#############预测
img = cv2.imread(img_path)
img = cv2.resize(img, (row, col))
img = np.expands(img, axis=0)
out = model.predict(img)
# 上述方法是不行的,仔细查看keras内置读取方式,可以观察到内置了load_img方式
# 因此,我们在预测时候,将读取图片的方式改为
from keras.preprocessing.image import load_img, img_to_array
img = load_img(img_path)
img = img_to_array(img, target_size=(row, col))
img = np.expands(img, axis=0)
out = model.predict(img)

注:本文意在说明 对训练数据和预测数据的读取、预处理方式上应该在某种程度上保持一致,从而避免训练结果和真实预测结果相差过大的情况。

来源:https://blog.csdn.net/qq_42972774/article/details/105101935

标签:Keras,model,predict,预测值
0
投稿

猜你喜欢

  • 使用Python中的greenlet包实现并发编程的入门教程

    2023-10-18 08:29:00
  • Windows下安装python MySQLdb遇到的问题及解决方法

    2022-07-20 13:22:36
  • 扫盲大讲堂:mysql出错的代码解析及解答

    2009-09-05 10:08:00
  • 去掉运行JavaScript时IE产生的警告栏

    2008-09-11 18:07:00
  • Django 报错:Broken pipe from ('127.0.0.1', 58924)的解决

    2021-03-27 21:12:09
  • 浅谈python正则的常用方法 覆盖范围70%以上

    2022-05-18 21:01:13
  • 一个ASP记录集分页显示的例子

    2007-09-14 10:57:00
  • 网站中文字的视觉设计

    2008-04-16 13:35:00
  • 牛刀小试YUI compressor(YUI安装方法)

    2009-02-12 16:18:00
  • Python 多核并行计算的示例代码

    2022-08-18 11:20:36
  • python3实现读取chrome浏览器cookie

    2023-10-18 13:18:44
  • php遍历目录方法小结

    2023-11-17 12:49:40
  • Laravel实现批量更新多条数据

    2023-10-23 03:23:03
  • Python实现列表拼接和去重的三种方式

    2021-05-02 23:43:54
  • go sync Once实现原理示例解析

    2023-07-01 12:21:13
  • Python如何安装第三方模块

    2023-08-01 12:50:07
  • CSS3的五个使用技巧[译]

    2009-02-19 13:01:00
  • 在ASP中改善动态分页的性能

    2008-05-08 14:27:00
  • Python两个整数相除得到浮点数值的方法

    2021-04-17 10:39:54
  • 获取Dom元素的X/Y坐标

    2009-10-10 12:49:00
  • asp之家 网络编程 m.aspxhome.com