TensorFlow 模型载入方法汇总(小结)

作者:叠加态的猫 时间:2022-11-09 00:05:42 

一、TensorFlow常规模型加载方法

保存模型

tf.train.Saver()类,.save(sess, ckpt文件目录)方法

参数名称功能说明默认值
var_listSaver中存储变量集合全局变量集合
reshape加载时是否恢复变量形状True
sharded是否将变量轮循放在所有设备上True
max_to_keep保留最近检查点个数5
restore_sequentially是否按顺序恢复变量,模型较大时顺序恢复内存消耗小True

var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。

如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。

加载模型

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化

TensorFlow 模型载入方法汇总(小结)

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:


ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)


TensorFlow 模型载入方法汇总(小结) 


.meta文件保存了当前图结构


.index文件保存了当前参数名


.data文件保存了当前参数值


tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象




ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)


saver = tf.train.Saver({"v/ExponentialMovingAverage":v})

saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载


saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)

1.不加载图结构,只加载参数

由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。


'''
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

with tf.Graph().as_default() as g:

x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
y = Net.inference_1(x, N_CLASS=5, train=False)

with tf.Session() as sess:
 # 程序前面得有 Variable 供 save or restore 才不报错
 # 否则会提示没有可保存的变量
 saver = tf.train.Saver()

ckpt = tf.train.get_checkpoint_state('./model/')
 img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
 img = sess.run(tf.expand_dims(tf.image.resize_images(
  tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))

if ckpt and ckpt.model_checkpoint_path:
  print(ckpt.model_checkpoint_path)
  saver.restore(sess,'./model/model.ckpt-0')
  global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  res = sess.run(y, feed_dict={x: img})
  print(global_step,sess.run(tf.argmax(res,1)))

2.加载图结构和参数


'''
直接使用使用保存好的图
无需加载python定义的结构,直接使用节点名称加载模型
由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错
现阶段不推荐使用,以后如果理解深入了可能会找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf

IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'

ckpt = tf.train.get_checkpoint_state('./model/')       # 通过检查点文件锁定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') # 载入图结构,保存在.meta文件中

with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path)      # 载入参数,参数保存在两个文件中,不过restore会自己寻找

img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
img = sess.run(tf.image.resize_images(
 tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
imgs = []
for i in range(128):
 imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))

'''
img = sess.run(tf.expand_dims(tf.image.resize_images(
 tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
print(img)
imgs = []
for i in range(128):
 imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
    feed_dict={'Placeholder:0':img}))

注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。

3.简化版本


# 连同图结构一同加载
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path)

# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./model/')
saver.restore(sess,ckpt.model_checkpoint_path)

二、TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作


# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
# 二进制读取模型文件
with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
 # 新建GraphDef文件,用于临时载入模型中的图
 graph_def = tf.GraphDef()
 # GraphDef加载模型中的图
 graph_def.ParseFromString(f.read())
 # 在空白图中加载GraphDef中的图
 tf.import_graph_def(graph_def,name='')
 # 在图中获取张量需要使用graph.get_tensor_by_name加张量名
 # 这里的张量可以直接用于session的run方法求值了
 # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
 self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
 self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]

来源:https://www.cnblogs.com/hellcat/p/6925757.html

标签:TensorFlow,模型载入
0
投稿

猜你喜欢

  • python 中 .py文件 转 .pyd文件的操作

    2022-02-17 09:59:38
  • JavaScript 关于引用那点事

    2009-11-28 18:44:00
  • js实现简单选项卡功能

    2024-04-22 13:05:47
  • php的对象传值与引用传值代码实例讲解

    2023-11-06 08:42:37
  • PHP组合模式Composite Pattern优点与实现过程

    2023-05-29 02:10:44
  • js从Cookies里面取值的简单实现

    2024-06-21 22:22:03
  • 在Apache服务器上同时运行多个Django程序的方法

    2022-05-16 11:16:09
  • 解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题

    2022-12-06 16:17:37
  • python实现学员管理系统

    2021-05-31 07:02:45
  • 基于Python编写一个语音合成系统

    2021-10-14 03:28:16
  • python模块导入方式浅析步骤

    2023-05-13 01:08:13
  • 基于OpenCV目标跟踪实现人员计数器

    2022-11-17 15:04:03
  • Python爬虫实例扒取2345天气预报

    2021-09-27 22:38:12
  • PHP队列用法实例

    2023-10-20 12:30:49
  • python梯度下降算法的实现

    2022-01-25 11:11:09
  • 基于Python3.6中的OpenCV实现图片色彩空间的转换

    2022-05-20 14:03:13
  • ML神器:sklearn的快速使用及入门

    2023-04-17 04:42:09
  • Python数据可视化详解

    2021-10-02 19:28:55
  • python 指定源路径来解决import问题的操作

    2023-04-28 00:03:01
  • python中 ? : 三元表达式的使用介绍

    2022-07-30 00:29:44
  • asp之家 网络编程 m.aspxhome.com