Keras模型转成tensorflow的.pb操作

作者:VickyD1023 时间:2023-12-22 13:10:34 

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码


from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os

base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)

model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
 freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
 output_names = output_names or []
 output_names += [v.op.name for v in tf.global_variables()]
 input_graph_def = graph.as_graph_def()
 if clear_devices:
  for node in input_graph_def.node:
   node.device = ""
 frozen_graph = convert_variables_to_constants(session, input_graph_def,
            output_names, freeze_var_names)
 return frozen_graph

output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)

print('input is :', model.input.name)
print ('output is:', model.output.name)

sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])

from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换


`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite


```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

--output_file=tf.tflite \
--graph_def_file=tf.pb \
--input_arrays=convolution2d_1_input \
--output_arrays=dense_3/BiasAdd \
--input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

来源:https://blog.csdn.net/q6324266/article/details/85262438

标签:Keras,tensorflow,.pb
0
投稿

猜你喜欢

  • Python读取文件内容的三种常用方式及效率比较

    2023-08-29 23:46:00
  • Mootools 1.2教程(12)——用Drag.Move实现拖拽和拖放

    2008-12-05 12:29:00
  • 使用auto.js实现自动化每日打卡功能

    2024-04-16 08:47:38
  • python图形界面教程Tkinter详解

    2021-01-08 04:27:47
  • Python如何处理大数据?3个技巧效率提升攻略(推荐)

    2022-04-02 10:03:03
  • php中关于hook钩子函数底层理解

    2023-06-12 06:49:55
  • Innodb表select查询顺序

    2024-01-16 03:32:40
  • 在PB中如何让用户只能修改新增的数据

    2023-11-27 15:59:52
  • vue中的路由传值与重调本路由改变参数

    2024-04-27 16:10:12
  • Pycharm设置界面全黑的方法

    2021-09-15 11:13:51
  • python3.X 抓取火车票信息【修正版】

    2022-01-26 01:24:53
  • python装饰器三种装饰模式的简单分析

    2022-06-26 17:29:46
  • 在ASP.NET 2.0中操作数据之五十二:使用FileUpload上传文件

    2023-07-07 04:19:18
  • 使用Spring Boot实现操作数据库的接口的过程

    2024-01-25 02:02:49
  • 5个充满想象力的Web调色板

    2008-08-02 12:55:00
  • python 日志模块 日志等级设置失效的解决方案

    2022-01-25 07:27:19
  • 对python3 urllib包与http包的使用详解

    2022-08-04 15:20:14
  • ASP、PHP与javascript根据时段切换CSS皮肤的代码

    2008-09-01 17:26:00
  • Python之 requests的使用(一)

    2023-01-06 16:02:09
  • 用js来解决ajax读取页面乱码

    2024-04-18 10:56:04
  • asp之家 网络编程 m.aspxhome.com