tensorflow实现从.ckpt文件中读取任意变量

作者:黑龙江小伙er 时间:2023-01-04 15:39:40 

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:


import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow

file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)


def fix_variables(self, sess, pretrained_model):
print('Fix VGG16 layers..')
with tf.variable_scope('Fix_VGG16') as scope:
 with tf.device("/cpu:0"):
  # fix the vgg16 issue from conv weights to fc weights
  # fix RGB to BGR
  fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
  fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
  conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
  restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv,
                 self._scope + "/fc7/weights": fc7_conv,
                 self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
  restorer_fc.restore(sess, pretrained_model) #恢复这些变量

sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv,
            self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
  sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv,
            self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
  sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'],
            tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数


from tensorflow.python import pywrap_tensorflow

model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量

file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)

sess=tf.Session()

my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
 if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
   print('Variables restored: %s' % v.name)
   variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
 elif v.name.split(':')[0] in var_to_shape_map:
   print('Variables restored: %s' % v.name)
   variables_to_restore[v.name]=v

restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:


fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)

restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
              "vgg_16/fc7/weights": fc7_conv,
              })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

来源:https://blog.csdn.net/weixin_39999955/article/details/80937112

标签:tensorflow,ckpt,变量
0
投稿

猜你喜欢

  • Mootools常用方法扩展(二)

    2009-01-11 18:22:00
  • Python input输入超时选择默认值自动跳过问题

    2023-02-22 07:22:40
  • 教你如何6秒钟往MySQL插入100万条数据的实现

    2024-01-19 02:17:35
  • 完美解决webstorm启动索引文件卡死的问题

    2022-04-05 05:52:29
  • Typescript中extends关键字的基本使用

    2024-06-18 01:03:50
  • Java+Spring+MySql环境中安装和配置MyBatis的教程

    2024-01-12 23:23:48
  • pyenv与virtualenv安装实现python多版本多项目管理

    2022-12-19 23:50:04
  • Python3.x检查内存可用大小的两种实现

    2022-03-24 07:17:03
  • python 详解turtle画爱心代码

    2022-05-09 20:44:58
  • XML卷之实战锦囊(2):动态查询

    2008-09-05 17:20:00
  • Python轻松搞定视频剪辑重复性工作问题

    2022-12-18 16:06:54
  • Pytorch之finetune使用详解

    2021-08-31 20:41:44
  • 利用js将ajax获取到的后台数据动态加载至网页中的方法

    2024-04-16 10:37:03
  • javascript面向对象编程(二)

    2008-03-07 12:59:00
  • 下载文件个别浏览器文件名乱码解决办法

    2024-04-17 10:05:04
  • 详解Python编程中基本的数学计算使用

    2022-12-12 13:52:04
  • 鼠年发几张可爱老鼠的表情gif

    2008-01-29 12:50:00
  • 基于python实现简单网页服务器代码实例

    2023-06-26 07:56:44
  • sqlserver 动态创建临时表的语句分享

    2012-01-29 17:54:37
  • MATLAB中print函数使用示例详解

    2023-11-18 04:18:04
  • asp之家 网络编程 m.aspxhome.com