python使用tensorflow保存、加载和使用模型的方法

作者:LordofRobots 时间:2021-01-25 13:19:26 

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:


#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut1_save.py
#Author: Wang  
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 11:04:25
############################

import tensorflow as tf

# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2')
b1 = tf.Variable(2.0, name = 'bias1')
feed_dict = {w1:[10,3], w2:[5,5]}

# define a test operation that will be restored
w3 = tf.add(w1, w2) # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name = "op_to_restore")

#saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print sess.run(w4, feed_dict)
#saver.save(sess, 'my_test_model', global_step = 100)
saver.save(sess, 'my_test_model')
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:


#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang  
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:16:38
############################  
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:


#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:33:35
############################

import tensorflow as tf

sess = tf.Session()

# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Second, access and create placeholders variables and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1], w2:[4,6]}

# Access the op that want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')

print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:


import tensorflow as tf
sess=tf.Session()  
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))

# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print sess.run(add_on_op,feed_dict)
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):


......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning  

#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')

#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()

new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

来源:http://blog.csdn.net/LordofRobots/article/details/77719020

标签:python,tensorflow
0
投稿

猜你喜欢

  • 使用Kubernetes集群环境部署MySQL数据库的实战记录

    2024-01-14 15:30:16
  • Python后台开发Django会话控制的实现

    2022-11-09 22:29:16
  • 自动定时备份sqlserver数据库的方法

    2024-01-13 20:45:14
  • Python机器学习NLP自然语言处理基本操作词袋模型

    2023-08-20 06:23:30
  • 分享一些可视信息设计资源

    2009-10-06 15:19:00
  • 详解Python中break语句的用法

    2021-12-21 22:18:17
  • python调用win32接口进行截图的示例

    2021-07-22 07:19:45
  • pytorch 实现查看网络中的参数

    2023-10-28 22:08:37
  • python for循环remove同一个list过程解析

    2023-03-20 22:07:48
  • python+selenium 定位到元素,无法点击的解决方法

    2022-02-01 12:29:52
  • python中使用enumerate函数遍历元素实例

    2021-05-08 04:56:41
  • Python爬虫入门案例之回车桌面壁纸网美女图片采集

    2022-12-25 19:40:57
  • 进制转换算法原理(二进制 八进制 十进制 十六进制)

    2022-01-09 03:18:23
  • Oracle跨数据库查询并插入实现原理及代码

    2024-01-14 18:52:58
  • 详解Appium+Python之生成html测试报告

    2022-12-21 22:38:51
  • python实现UDP协议下的文件传输

    2023-10-10 10:26:20
  • mysql 模糊查询 concat()的用法详解

    2024-01-14 01:48:46
  • django中forms组件的使用与注意

    2021-03-11 00:14:04
  • 关于python爬虫应用urllib库作用分析

    2023-11-02 12:59:43
  • 利用Python实现绘制3D爱心的代码分享

    2021-03-30 23:25:12
  • asp之家 网络编程 m.aspxhome.com