Tensorflow2.1 MNIST图像分类实现思路分析
作者:我是王大你是谁 发布时间:2023-04-17 03:35:32
前言
之前工作中主要使用的是 Tensorflow 1.15 版本,但是渐渐跟不上工作中的项目需求了,而且因为 2.x 版本和 1.x 版本差异较大,所以要专门花时间学习一下 2.x 版本,本文作为学习 Tensorflow 2.x 版本的开篇,主要介绍了使用 cpu 版本的 Tensorflow 2.1 搭建深度学习模型,完成对于 MNIST 数据的图片分类的任务。
主要思路和实现
(1) 加载数据,处理数据
这里是要导入 tensorflow 的包,前提是你要提前安装 tensorflow ,我这里为了方便直接使用的是 cpu 版本的 tensorflow==2.1.0 ,如果是为了学习的话,cpu 版本的也够用了,毕竟数据量和模型都不大。
import tensorflow as tf
这里是为了加载 mnist 数据集,mnist 数据集里面就是 0-9 这 10 个数字的图片集,我们要使用深度学习实现一个模型完成对 mnist 数据集进行分类的任务,这个项目相当于 java 中 hello world 。
mnist = tf.keras.datasets.mnist
这里的 (x_train, y_train) 表示的是训练集的图片和标签,(x_test, y_test) 表示的是测试集的图片和标签。
(x_train, y_train), (x_test, y_test) = mnist.load_data()
每张图片是 28*28 个像素点(数字)组成的,而每个像素点(数字)都是 0-255 中的某个数字,我们对其都除 255 ,这样就是相当于对这些图片的像素点值做归一化,这样有利于模型加速收敛,在本项目中执行本操作比不执行本操作最后的准确率高很多,在文末会展示注释本行情况下,模型评估的指标结果,大家可以自行对比差异。
x_train, x_test = x_train / 255.0, x_test / 255.0
(2) 使用 keras 搭建深度学习模型
这里主要是要构建机器学习模型,模型分为以下几层:
第一层要接收图片的输入,每张图片是 28*28 个像素点组成的,所以 input_shape=(28, 28)
第二层是一个输出 128 维度的全连接操作
第三层是要对第二层的输出随机丢弃 20% 的 Dropout 操作,这样有利于模型的泛化
第四层是一个输出 10 维度的全连接操作,也就是预测该图片分别属于这十种类型的概率
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10)
])
(3) 定义损失函数
这里主要是定义损失函数,这里的损失函数使用到了 SparseCategoricalCrossentropy ,主要是为了计算标签和预测结果之间的交叉熵损失。
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
(4) 配置编译模型
这里主要是配置和编译模型,优化器使用了 adam ,要优化的评价指标选用了准确率 accuracy ,当然了还可以选择其他的优化器和评价指标。
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
(5) 使用训练数据训练模型
这里主要使用训练数据的图片和标签来训练模型,将整个训练样本集训练 5 次。
model.fit(x_train, y_train, epochs=5)
训练过程结果输出如下:
Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 3s 43us/sample - loss: 0.2949 - accuracy: 0.9144
Epoch 2/5
60000/60000 [==============================] - 2s 40us/sample - loss: 0.1434 - accuracy: 0.9574
Epoch 3/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.1060 - accuracy: 0.9676
Epoch 4/5
60000/60000 [==============================] - 2s 31us/sample - loss: 0.0891 - accuracy: 0.9721
Epoch 5/5
60000/60000 [==============================] - 2s 29us/sample - loss: 0.0740 - accuracy: 0.9771
10000/10000 - 0s - loss: 0.0744 - accuracy: 0.9777
(6) 使用测试数据评估模型
这里主要是使用测试数据中的图片和标签来评估模型,verbose 可以选为 0、1、2 ,区别主要是结果输出的形式不一样,嫌麻烦可以不设置
model.evaluate(x_test, y_test, verbose=2)
评估的损失值和准确率如下:
[0.07444974237508141, 0.9777]
(7) 展示不使用归一化的操作的训练和评估结果
在不使用归一化操作的情况下,训练过程输出如下:
Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 3s 42us/sample - loss: 2.4383 - accuracy: 0.7449
Epoch 2/5
60000/60000 [==============================] - 2s 40us/sample - loss: 0.5852 - accuracy: 0.8432
Epoch 3/5
60000/60000 [==============================] - 2s 36us/sample - loss: 0.4770 - accuracy: 0.8724
Epoch 4/5
60000/60000 [==============================] - 2s 34us/sample - loss: 0.4069 - accuracy: 0.8950
Epoch 5/5
60000/60000 [==============================] - 2s 32us/sample - loss: 0.3897 - accuracy: 0.8996
10000/10000 - 0s - loss: 0.2898 - accuracy: 0.9285
评估结果输入如下:
[0.2897613683119416, 0.9285]
所以我们通过和上面的进行对比发现,不进行归一化操作,在训练过程中收敛较慢,在相同 epoch 的训练之后,评估的准确率和损失值都不理想,损失值比第(6)步操作的损失值大,准确率比第(6)步操作低 5% 左右。
来源:https://juejin.cn/post/7164678763903451172


猜你喜欢
- 加载垃圾邮件数据集spambase.csv(数据集基本信息:样本数: 4601,特征数量: 57, 类别:1 为垃圾邮件,0 为
- 本文实例讲述了Python使用add_subplot与subplot画子图操作。分享给大家供大家参考,具体如下:子图:就是在一张figure
- dom元素内部内容是动态的,重置数据后直接获取宽高总是不准确:this.$refs.editor[0].offsetHeight;原因:重置
- 整理了一下python 中文件的输入输出及主要介绍一些os模块中对文件系统的操作。文件输入输出1、内建函数open(file_name,文件
- 前言很多人会使用postman工具,或者熟悉python,但不一定会使用python来编写测试用例脚本,postman里面可以完整的将pyt
- el-col-group"el-col-group" 是一个 Vue.js 函数式组件,允许您在 "el-ta
- 1 丰富的二维动画/图形和视音频表现 Rich 2D animation/graphics with audio and video这点毋庸
- 1.auto close tagHTML自动补全标签2.beautiful UI32个主题集合,具体使用看个人喜好。3.better com
- 跟小组里一自称小方方的卖萌90小青年聊天,IT男的坏习惯,聊着聊着就扯到技术上去了,小方方突然问 1、声明一个数值类型的变量我看到三种,区别
- 一、SQLAlchemy 介绍1.1 ORM 的概念ORM全称Object Relational Mapping(对象关系映射),通过 OR
- JavaScript网页–跨年倒计时,供大家参考,具体内容如下最近学弟在追一个学妹,我在帮学弟出谋划策。学妹告诉学弟,我怕我们之间是因为这段
- 在python中利用numpy array进行数据处理,经常需要找出符合某些要求的数据位置,有时候还需要对这些位置重新赋值。这里总结了几种找
- create database MyDb on ( name=mainDb, filename='c:\MyDb\mainDb.md
- 第一种--对象键值去重Array.prototype.unique1 = function () { var r
- 前言风玫瑰是由气象学家用于给出如何风速和风向在特定位置通常分布的简明视图的图形工具。它也可以用来描述空气质量污染源。风玫瑰工具使用Matpl
- pycharm指定python路径,pycharm配置python环境的方法是:1、依次点击【File】、【Project Interpre
- 前言相信大家在日常的web开发中,作为前端经常会遇到处理图片拉伸问题的情况。例如banner、图文列表、头像等所有和用户或客户自主操作图片上
- 一、技术背景损失函数是机器学习中直接决定训练结果好坏的一个模块,该函数用于定义计算出来的结果或者是神经网络给出的推测结论与正确结果的偏差程度
- 日期的转换及计算对于日期,有时需执行不同时间单位的转换,或者接受字符串格式的日期,转换为 datetime 对象。有时需计算日期的范围,以及
- 学校让我们在放假期间自觉Python,对于Python我是小白的不能再小白了。一切从头开始,找学习资料,看视频教程光看书看视频也不行还要自己