浅谈tensorflow模型保存为pb的各种姿势

作者:googler_offer 时间:2023-02-01 21:19:05 

一,直接保存pb

1, 首先我们当然可以直接在tensorflow训练中直接保存为pb为格式,保存pb的好处就是使用场景是实现创建模型与使用模型的解耦,使得创建模型与使用模型的解耦,使得前向推导inference代码统一。另外的好处就是保存为pb的时候,模型的变量会变成固定的,导致模型的大小会大大减小。

这里稍稍解释下pb:是MetaGraph的protocol buffer格式的文件,MetaGraph包括计算图,数据流,以及相关的变量和输入输出

主要使用tf.SavedModelBuilder来完成这个工作,并且可以把多个计算图保存到一个pb文件中,如果有多个MetaGraph,那么只会保留第一个MetaGraph的版本号。

保持pb的文件代码:


import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
# 这里的输出需要加上name属性
op = tf.add(xy, b, name='op_to_store')

sess.run(tf.global_variables_initializer())

# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

# 测试 OP
feed_dict = {x: 10, y: 3}
print(sess.run(op, feed_dict))

# 写入序列化的 PB 文件
with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
 f.write(constant_graph.SerializeToString())

# 输出
# INFO:tensorflow:Froze 1 variables.
# Converted 1 variables to const ops.
# 31

其实主要是:


# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

# 写入序列化的 PB 文件
with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
 f.write(constant_graph.SerializeToString())

1.1 加载测试代码


from tensorflow.python.platform import gfile

sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='') # 导入计算图

# 需要有一个初始化的过程
sess.run(tf.global_variables_initializer())

# 需要先复原变量
print(sess.run('b:0'))
# 1

# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

op = sess.graph.get_tensor_by_name('op_to_store:0')

ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

2,第二种就是采用上述的那API来进行保存


import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
# 这里的输出需要加上name属性
op = tf.add(xy, b, name='op_to_store')

sess.run(tf.global_variables_initializer())

# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

# 测试 OP
feed_dict = {x: 10, y: 3}
print(sess.run(op, feed_dict))

# 写入序列化的 PB 文件
with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
 f.write(constant_graph.SerializeToString())

# INFO:tensorflow:Froze 1 variables.
# Converted 1 variables to const ops.
# 31

# 官网有误,写成了 saved_model_builder
builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
# 构造模型保存的内容,指定要保存的 session,特定的 tag,
# 输入输出信息字典,额外的信息
builder.add_meta_graph_and_variables(sess,
         ['cpu_server_1'])

# 添加第二个 MetaGraphDef
#with tf.Session(graph=tf.Graph()) as sess:
# ...
# builder.add_meta_graph([tag_constants.SERVING])
#...

builder.save() # 保存 PB 模型

核心就是采用了:


# 官网有误,写成了 saved_model_builder
builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
# 构造模型保存的内容,指定要保存的 session,特定的 tag,
# 输入输出信息字典,额外的信息
builder.add_meta_graph_and_variables(sess,
         ['cpu_server_1'])

2.1 对应的测试代码为:


with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel')
sess.run(tf.global_variables_initializer())

input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

op = sess.graph.get_tensor_by_name('op_to_store:0')

ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单

这样和之前的导入pb模型一样,也是要知道tensor的name,那么如何在不知道tensor name的情况下使用呢,给add_meta_graph_and_variables方法传入第三个参数,signature_def_map即可。

二,从ckpt进行加载

使用tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:


import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
sess.run(init_op)
print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
print("v2:", sess.run(v2))
saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
print("Model saved in file:", saver_path)

浅谈tensorflow模型保存为pb的各种姿势

checkpoint是检查点的文件,文件保存了一个目录下所有的模型文件列表

model.ckpt.meta文件保存了Tensorflow计算图的结果,可以理解为神经网络的网络结构,该文件可以被tf.train.import_meta_graph加载到当前默认的图来使用

ckpt.data是保存模型中每个变量的取值

方法一, tensorflow提供了convert_variables_to_constants()方法,改方法可以固化模型结构,将计算图中的变量取值以常量的形式保存

ckpt转换pb格式过程如下:

1,通过传入ckpt模型的路径得到模型的图和变量数据

2,通过import_meta_graph导入模型中的图

3,通过saver.restore从模型中恢复图中各个变量的数据

4,通过graph_util.convert_variables_to_constants将模型持久化


import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.pyton.platform import gfile

def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图

with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess=sess,
  input_graph_def=input_graph_def,# 等于:sess.graph_def
  output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开

with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
  f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

# for op in graph.get_operations():
 #  print(op.name, op.values())

函数freeze_graph中,最重要的就是指定输出节点的名称,这个节点名称是原模型存在的结点,注意节点名称与张量名称的区别:

如:“input:0”是张量的名称,而“input”表示的是节点的名称

源码中通过graph = tf.get_default_graph()获得默认图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此就必须执行tf.train.import_meta_graph,再执行tf.get_default_graph()

1.2 一个小工具

tensorflow打印pb模型的所有节点


from tensorflow.python.framework import tensor_util
from google.protobuf import text_format
import tensorflow as tf
from tensorflow.python.platform import gfile
from tensorflow.python.framework import tensor_util

pb_path = './model.pb'

with tf.Session() as sess:
with gfile.FastGFile(pb_path,'rb') as f:
 graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())
 tf.import_graph_def(graph_def,name='')
 for i,n in enumerate(graph_def.node):
  print("Name of the node -%s"%n.name)
tensorflow打印ckpt的所有节点

from tensorflow.python import pywrap_tensorflow
checkpoint_path = './_checkpoint/hed.ckpt-130'

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name:",key)

方法二,除了上述办法外还有一种是需要通过源码的,这样既可以得到输出节点,还可以自定义输入节点。


import tensorflow as tf

def model(input):
net = tf.layers.conv2d(input,filters=32,kernel_size=3)
net = tf.layers.batch_normalization(net,fused=False)
net = tf.layers.separable_conv2d(net,32,3)
net = tf.layers.conv2d(net,filters=32,kernel_size=3,name='output')

return net

input_node = tf.placeholder(tf.float32,[1,480,480,3],name = 'image')
output_node_names = 'head_neck_count/BiasAdd'
ckpt = ckpt_path
pb = pb_path

with tf.Session() as sess:
model1 = model(input_node)
sess.run(tf.global_variables_initializer())
output_node_names = 'output/BiasAdd'

input_graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names.split(','))

with tf.gfile.GFile(pb,'wb') as f:
f.write(output_graph_def.SerializeToString())

注意:

节点名称和张量名称区别

类似于output是节点名称

类似于output:0是张量名称

方法三,其实是方法一的延伸可以配合tensorflow自带的一些工具来进行完成

freeze_graph

总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):

1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

7、filename_tensor_name:(可选)已弃用。默认:save/Const:0

8、output_graph:(必选)用来保存整合后的模型输出文件。

9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

所以还是建议选择方法三

导出pb后的测试代码如下:下图是比较完成的测试代码与导出代码。


# -*-coding: utf-8 -*-
"""
@Project: tensorflow_models_nets
@File : convert_pb.py
@Author : panjq
@E-mail : pan_jinquan@163.com
@Date : 2018-08-29 17:46:50
@info :
-通过传入 CKPT 模型的路径得到模型的图和变量数据
-通过 import_meta_graph 导入模型中的图
-通过 saver.restore 从模型中恢复图中各个变量的数据
-通过 graph_util.convert_variables_to_constants 将模型持久化
"""

import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util

resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3

def freeze_graph_test(pb_path, image_path):
'''
:param pb_path:pb文件的路径
:param image_path:测试图片的路径
:return:
'''
with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
  output_graph_def.ParseFromString(f.read())
  tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

# 定义输入的张量名称,对应网络结构的输入张量
  # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
  input_image_tensor = sess.graph.get_tensor_by_name("input:0")
  input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
  input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")

# 定义输出的张量名称
  output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

# 读取测试图片
  im=read_image(image_path,resize_height,resize_width,normalization=True)
  im=im[np.newaxis,:]
  # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
  # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
  out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
             input_keep_prob_tensor:1.0,
             input_is_training_tensor:False})
  print("out:{}".format(out))
  score = tf.nn.softmax(out, name='pre')
  class_id = tf.argmax(score, 1)
  print "pre class_id:{}".format(sess.run(class_id))

def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess=sess,
  input_graph_def=sess.graph_def,# 等于:sess.graph_def
  output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开

with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
  f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

# for op in sess.graph.get_operations():
 #  print(op.name, op.values())

def freeze_graph2(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图

with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess=sess,
  input_graph_def=input_graph_def,# 等于:sess.graph_def
  output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开

with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
  f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

# for op in graph.get_operations():
 #  print(op.name, op.values())

if __name__ == '__main__':
# 输入ckpt模型路径
input_checkpoint='models/model.ckpt-10000'
# 输出pb模型的路径
out_pb_path="models/pb/frozen_model.pb"
# 调用freeze_graph将ckpt转为pb
freeze_graph(input_checkpoint,out_pb_path)

# 测试pb模型
image_path = 'test_image/animal.jpg'
freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

来源:https://blog.csdn.net/googler_offer/article/details/88577458

标签:tensorflow,模型,保存,pb
0
投稿

猜你喜欢

  • 用javascript来实现仿gogle动画导航

    2007-11-30 14:15:00
  • vue+elementUi图片上传组件使用详解

    2024-05-10 14:14:49
  • Python实现图片转字符画的示例代码

    2021-07-13 19:21:03
  • python区分不同数据类型的方法

    2022-03-30 23:37:41
  • Linux下安装MySQL5.7.19问题小结

    2024-01-16 06:21:37
  • layui 上传文件_批量导入数据UI的方法

    2024-05-22 10:36:41
  • JS字符串拼接的几种方式(最新推荐)

    2024-04-10 16:11:41
  • Pytorch数据类型与转换(torch.tensor,torch.FloatTensor)

    2023-03-31 13:32:36
  • 9种使用Chrome Firefox 自带调试工具调试javascript技巧

    2023-07-19 01:03:48
  • Python变量作用域LEGB用法解析

    2022-12-05 19:18:22
  • 深刻理解Oracle数据库的启动和关闭

    2010-07-26 13:08:00
  • Django用数据库表反向生成models类知识点详解

    2024-01-25 15:19:20
  • Python打包为exe详细教程

    2023-08-23 03:00:21
  • 利用Python找出删除自己微信的好友并将他们自动化删除

    2022-05-23 00:10:13
  • 利用Python打造一个多人聊天室的示例详解

    2023-04-10 15:22:11
  • Python时间的精准正则匹配方法分析

    2022-12-10 12:59:28
  • Python导入Excel表格数据并以字典dict格式保存的操作方法

    2023-05-25 17:58:37
  • 通过按钮实时切换CSS样式 实现CSS换肤的实例

    2008-07-17 12:55:00
  • MySql数据引擎简介与选择方法

    2024-01-28 12:04:29
  • 疯狂上涨的Python 开发者应从2.x还是3.x着手?

    2021-10-25 16:41:54
  • asp之家 网络编程 m.aspxhome.com