python神经网络tensorflow利用训练好的模型进行预测

作者:Bubbliiiing 时间:2022-09-27 17:33:17 

学习前言

在神经网络学习中slim常用函数与如何训练、保存模型文章里已经讲述了如何使用slim训练出来一个模型,这篇文章将会讲述如何预测。

载入模型思路

载入模型的过程主要分为以下四步:

1、建立会话Session;

2、将img_input的placeholder传入网络,建立网络结构;

3、初始化所有变量;

4、利用saver对象restore载入所有参数。

这里要注意的重点是,在利用saver对象restore载入所有参数之前,必须要建立网络结构,因为网络结构对应着cpkt文件中的参数。

(网络层具有对应的名称scope。)

python神经网络tensorflow利用训练好的模型进行预测

实现代码

在运行实验代码前,可以直接下载代码,因为存在许多依赖的文件

import tensorflow as tf
import numpy as np
from nets import Net
from tensorflow.examples.tutorials.mnist import input_data
def compute_accuracy(x_data,y_data):
   global prediction
   y_pre = sess.run(prediction,feed_dict={img_input:x_data})
   correct_prediction = tf.equal(tf.arg_max(y_data,1),tf.arg_max(y_pre,1))
   accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
   result = sess.run(accuracy,feed_dict = {img_input:x_data})
   return result
mnist = input_data.read_data_sets("MNIST_data",one_hot = "true")
slim = tf.contrib.slim
# img_input的placeholder
img_input = tf.placeholder(tf.float32, shape = (None, 784))
img_reshape = tf.reshape(img_input,shape = (-1,28,28,1))
# 载入模型
sess = tf.Session()
Conv_Net = Net.Conv_Net()
# 将img_input的placeholder传入网络
prediction = Conv_Net.net(img_reshape)
# 载入模型
ckpt_filename = './logs/model.ckpt-20000'
# 初始化所有变量
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# 恢复
saver.restore(sess, ckpt_filename)
print(compute_accuracy(mnist.test.images,mnist.test.labels))

运行结果为:

0.9921

来源:https://blog.csdn.net/weixin_44791964/article/details/102584474

标签:python,神经网络,tensorflow,模型预测,训练好的模型
0
投稿

猜你喜欢

  • javascript 用函数语句和表达式定义函数的区别详解

    2024-04-16 09:06:26
  • MySQL之批量插入的4种方案总结

    2024-01-19 16:13:11
  • python入门学习之自带help功能初步使用示例

    2021-05-27 17:07:28
  • Golang 拷贝Array或Slice的操作

    2024-04-30 10:02:40
  • 关于Mysql中文乱码问题该如何解决(乱码问题完美解决方案)

    2024-01-13 20:06:12
  • HTML5 Canvas 起步(2) - 路径

    2009-05-12 12:06:00
  • 通过模版字符串及JSON数据进行目标内容整理的一个小方法

    2010-01-12 16:55:00
  • python pickle存储、读取大数据量列表、字典数据的方法

    2021-10-01 11:22:51
  • 用yum安装MySQLdb模块的步骤方法

    2024-01-12 18:23:25
  • python 实现简单的FTP程序

    2021-03-29 10:33:00
  • python导入坐标点的具体操作

    2023-02-24 19:59:46
  • 浅析Go设计模式之Facade(外观)模式

    2023-07-16 19:23:08
  • Jquery.TreeView结合ASP.Net和数据库生成菜单导航条

    2024-01-15 01:23:44
  • 使用Python生成url短链接的方法

    2021-05-05 02:55:12
  • python清空命令行方式

    2023-12-08 09:50:35
  • Python检测PE所启用保护方式详解

    2022-03-11 12:36:08
  • Javascript函数类型判断解决方案

    2009-08-27 15:32:00
  • 为什么Python中没有"a++"这种写法

    2023-12-04 09:40:57
  • 详谈构造函数加括号与不加括号的区别

    2024-06-15 23:06:44
  • 详解Go语言设计模式之单例模式

    2024-03-26 13:53:37
  • asp之家 网络编程 m.aspxhome.com