Pytorch转tflite方式
作者:IT修道者 发布时间:2023-08-10 03:47:29
标签:Pytorch,tflite
目标是想把在服务器上用pytorch训练好的模型转换为可以在移动端运行的tflite模型。
最直接的思路是想把pytorch模型转换为tensorflow的模型,然后转换为tflite。但是这个转换目前没有发现比较靠谱的方法。
经过调研发现最新的tflite已经支持直接从keras模型的转换,所以可以采用keras作为中间转换的桥梁,这样就能充分利用keras高层API的便利性。
转换的基本思想就是用pytorch中的各层网络的权重取出来后直接赋值给keras网络中的对应layer层的权重。
转换为Keras模型后,再通过tf.contrib.lite.TocoConverter把模型直接转为tflite.
下面是一个例子,假设转换的是一个两层的CNN网络。
import tensorflow as tf
from tensorflow import keras
import numpy as np
import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable
class PytorchNet(nn.Module):
def __init__(self):
super(PytorchNet, self).__init__()
conv1 = nn.Sequential(
nn.Conv2d(3, 32, 3, 2),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2))
conv2 = nn.Sequential(
nn.Conv2d(32, 64, 3, 1, groups=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2))
self.feature = nn.Sequential(conv1, conv2)
self.init_weights()
def forward(self, x):
return self.feature(x)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight.data, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()
if isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def KerasNet(input_shape=(224, 224, 3)):
image_input = keras.layers.Input(shape=input_shape)
# conv1
network = keras.layers.Conv2D(
32, (3, 3), strides=(2, 2), padding="valid")(image_input)
network = keras.layers.BatchNormalization(
trainable=False, fused=False)(network)
network = keras.layers.Activation("relu")(network)
network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)
# conv2
network = keras.layers.Conv2D(
64, (3, 3), strides=(1, 1), padding="valid")(network)
network = keras.layers.BatchNormalization(
trainable=False, fused=True)(network)
network = keras.layers.Activation("relu")(network)
network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)
model = keras.Model(inputs=image_input, outputs=network)
return model
class PytorchToKeras(object):
def __init__(self, pModel, kModel):
super(PytorchToKeras, self)
self.__source_layers = []
self.__target_layers = []
self.pModel = pModel
self.kModel = kModel
tf.keras.backend.set_learning_phase(0)
def __retrieve_k_layers(self):
for i, layer in enumerate(self.kModel.layers):
if len(layer.weights) > 0:
self.__target_layers.append(i)
def __retrieve_p_layers(self, input_size):
input = torch.randn(input_size)
input = Variable(input.unsqueeze(0))
hooks = []
def add_hooks(module):
def hook(module, input, output):
if hasattr(module, "weight"):
# print(module)
self.__source_layers.append(module)
if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
hooks.append(module.register_forward_hook(hook))
self.pModel.apply(add_hooks)
self.pModel(input)
for hook in hooks:
hook.remove()
def convert(self, input_size):
self.__retrieve_k_layers()
self.__retrieve_p_layers(input_size)
for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
print(source_layer)
weight_size = len(source_layer.weight.data.size())
transpose_dims = []
for i in range(weight_size):
transpose_dims.append(weight_size - i - 1)
if isinstance(source_layer, nn.Conv2d):
transpose_dims = [2,3,1,0]
self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(
).transpose(transpose_dims), source_layer.bias.data.numpy()])
elif isinstance(source_layer, nn.BatchNorm2d):
self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(),
source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()])
def save_model(self, output_file):
self.kModel.save(output_file)
def save_weights(self, output_file):
self.kModel.save_weights(output_file, save_format='h5')
pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))
torch.save(pytorch_model, 'test.pth')
#Load the pretrained model
pytorch_model = torch.load('test.pth')
# #Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))
# #Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")
# convert keras model to tflite model
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
"keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)
补充知识:tensorflow模型转换成tensorflow lite模型
1.把graph和网络模型打包在一个文件中
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=eval_graph_def.pb \
--input_checkpoint=checkpoint \
--output_graph=frozen_eval_graph.pb \
--output_node_names=outputs
For example:
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
--input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
--output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
--output_node_names=MobilenetV1/Predictions/Reshape_1
2.把第一步中生成的tensorflow pb模型转换为tf lite模型
转换前需要先编译转换工具
bazel build tensorflow/contrib/lite/toco:toco
转换分两种,一种的转换为float的tf lite,另一种可以转换为对模型进行unit8的量化版本的模型。两种方式如下:
非量化的转换:
./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ 官网给的这个路径不对
./bazel-bin/tensorflow/contrib/lite/toco/toco \
—input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
—output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \
--input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
--inference_type=FLOAT \
--input_shape="1,224, 224,3" \
--input_array=input \
--output_array=MobilenetV1/Predictions/Reshape_1
量化方式的转换(注意,只有量化训练的模型才能进行量化的tf_lite转换):
./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \
./bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=frozen_eval_graph.pb \
--output_file=tflite_model.tflite \
--input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
--inference_type=QUANTIZED_UINT8 \
--input_shape="1,224, 224,3" \
--input_array=input \
--output_array=outputs \
--std_value=127.5 --mean_value=127.5
来源:https://blog.csdn.net/computerme/article/details/84144930
0
投稿
猜你喜欢
- 学生信息管理系统负责编辑学生信息,供大家参考,具体内容如下第一次发帖,下面通过python实现一个简单的学生信息管理系统要求如下:1.添加学
- 本文测试环境:CentOS 7 64-bit Minimal MySQL 5.7配置 yum 源在 https://dev.mysql.co
- 点云生成 3D 网格的最快方法已经用 Python 编写了几个实现来从点云中获取网格。它们中的大多数
- 对,你没看错,这是我初学 python 时的灵魂发问。我们总会在class里面看见self,但是感觉他好像也没什么用处,就是放在那里占个位子
- 方法1:import requestsurl = "http://www.xxxx.net/login"#参数拼凑,附件
- 很多开发人员在刚开始学Python 时,都考虑过像 c++ 那样来实现 singleton 模式,但后来会发现 c++ 是 c++,Pyth
- 郁闷的事来了,先看前台HTML: 购买数量: <input id="txtNum" type="text
- 在我们平常使用Python进行数据处理与分析时,在import完一大堆库之后,就是对数据进行预览,查看数据是否出现了缺失值、重复值等异常情况
- 一. 静态资源static文件放在app中确认django.contrib.staticfiles包含在INSTALLED_APPS中。在s
- 前言正则表达式是文本处理领域中的一个强大的工具,它可以让文本处理的能力呈指数级的提升,如果一款文本编辑器不支持正则表达式,那么它就算不上是一
- 前言Go语言的序列化与反序列化在工作中十分常用,在Go语言中提供了相关的解析方法去解析JSON,操作也比较简单序列化// 数据序列化func
- __str__和__repr__的异同?字符串的表示形式我们都知道,Python的内置函数repr()能够把对象用字符串的形式表达出来,方便
- sql语句/*MySQL 消除重复行的一些方法---Chu Minfei---2010-08-12 22:49:44.660--引用转载请注
- 一、前言最近忙里偷闲,做了一个部署数据库及IIS网站站点的WPF应用程序工具。 二、内容此工具的目的是:根据.sql文件在本机上部
- 复制目录: 包含多层子目录方法: 递归, 深度遍历,广度遍历深度遍历&广度遍历:思路:1.获得源目录子级目录,并设置目标目录的子级路
- Turtle库是Python语言中一个很流行的绘制图像的函数库,想象一个小乌龟,在一个横轴为x、纵轴为y的坐标系原点,(0,0)位置开始,它
- ASP是Web上的客户机/服务器结构的中间层,虽然它使用脚本语言(Java Script,VB Script等)编写,程序代码在服务器上运行
- 为什么ASP.NET Core采用Main方法?需要记住的最重要的一点是,ASP.NET Core Web 应用程序最初作为控制台应用程序启
- 本文实例讲述了Python中map和列表推导效率比较。分享给大家供大家参考。具体分析如下:直接来测试代码吧:#!/usr/bin/env p
- 下面有两种方法都可以:import numpy as npa=np.asarray([[10,20],[101,201]])# a=a[:,