详解tensorflow实现迁移学习实例
作者:疯女孩爱飞 时间:2022-02-06 01:43:22
本文主要是总结利用tensorflow实现迁移学习的基本步骤。
所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。
持久化
首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。
保存图
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess, "model.ckpt")
加载图
saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess, "model.ckpt")
迁移学习
第一步: 读取加载已经训练好的模型
在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。
def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})
# 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
bottlenect_values = np.squeeze(bottlenect_values)
return bottlenect_values
该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。
第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)
以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持脚本之家。
来源:http://blog.csdn.net/ustbfym/article/details/78201575
标签:tensorflow,迁移
0
投稿
猜你喜欢
MySQL中Replace语句用法实例详解
2024-01-15 03:26:28
Anaconda+vscode+pytorch环境搭建过程详解
2022-04-06 01:37:19
python高级特性和高阶函数及使用详解
2022-09-17 20:13:50
Python简直是万能的,这5大主要用途你一定要知道!(推荐)
2021-03-16 16:20:31
Python pygame实现中国象棋单机版源码
2021-04-15 05:34:16
Python中的Function定义方法第1/2页
2021-05-10 20:33:49
PHP同时连接多个mysql数据库示例代码
2023-11-23 21:12:28
python 读取数据库并绘图的实例
2024-01-27 10:43:33
详解基于Jupyter notebooks采用sklearn库实现多元回归方程编程
2022-04-19 21:35:31
在ipython notebook中使用argparse方式
2021-11-17 08:58:41
详解Mysql查询条件中字符串尾部有空格也能匹配上的问题
2024-01-13 11:06:50
python爬取基于m3u8协议的ts文件并合并
2021-11-03 16:44:45
mysql alter语句用法实例
2024-01-25 12:32:53
XML+ JS创建树形菜单
2013-08-22 08:30:17
基于python代码批量处理图片resize
2022-03-18 23:06:58
Python 实现将大图切片成小图,将小图组合成大图的例子
2023-05-23 18:08:36
详解nodejs内置模块
2024-05-03 15:54:20
解决ele ui 表格表头太长问题的实现
2024-05-13 09:44:00
简单介绍Python中的readline()方法的使用
2023-11-02 13:34:30
基于Python检测动态物体颜色过程解析
2022-03-20 09:07:30