tensorflow 加载部分变量的实例讲解

作者:imperfect00 时间:2023-03-27 03:19:56 

tensorflow模型保存为saver = tf.train.Saver()函数,saver.save()保存模型,代码如下:


import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
saver.save(sess,"checkpoint/model_test",global_step=1)

当我们保存模型后,我们可以通过saver.restore()来加载模型,初始化变量:


import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
# init_op = tf.global_variables_initializer()
# sess.run(init_op)
saver.restore(sess, "checkpoint/model_test-1")
# saver.save(sess,"checkpoint/model_test",global_step=1)

神经网络训练时,有时候我们需要从预训练的模型中加载部分参数,初始化当前模型,例如加入CNN有6层,我们需要从已有的模型初始化CNN前5层参数.这可以通过saver.restore()实现.

之前我们已经介绍可以通过tf.train.Saver()的保存部分变量的方法,即需要保存的变量列表,同样的,在变量初始化的时候,我们可以对需要单独初始化的变量分别定义一个tf.train.Saver()函数,这样就可以单独对该部分变量初始化,例如下面代码,saver1用于初始化变量v1,saver2用于初始化变量v2,v3:


import tensorflow as tf

v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
#saver = tf.train.Saver()
saver1 = tf.train.Saver([v1])
saver2 = tf.train.Saver([v2]+[v3])
with tf.Session() as sess:
# init_op = tf.global_variables_initializer()
# sess.run(init_op)
saver1.restore(sess, "checkpoint/model_test-1")
saver2.restore(sess, "checkpoint/model_test-1")
# saver.save(sess,"checkpoint/model_test",global_step=1)

来源:https://blog.csdn.net/u011961856/article/details/76850335

标签:tensorflow,加载,变量
0
投稿

猜你喜欢

  • mysql跨数据库复制表(在同一IP地址中)示例

    2024-01-20 00:51:11
  • asp.net中通过ALinq让Mysql操作变得如此简单

    2024-01-21 06:53:41
  • python轻松办公将100个Excel中符合条件的数据汇总到1个Excel里

    2021-08-12 03:17:11
  • PHP判断密码强度的方法详解

    2023-06-14 03:00:08
  • 微信小程序MUI导航栏透明渐变功能示例(通过改变rgba的a值实现)

    2024-05-11 09:42:52
  • Python 3.6 中使用pdfminer解析pdf文件的实现

    2023-09-02 08:34:08
  • Python自然语言处理之词干,词形与最大匹配算法代码详解

    2023-07-23 04:48:37
  • python教程之进程和线程

    2021-09-27 02:54:00
  • Golang语言学习拿捏Go反射示例教程

    2023-06-22 23:30:23
  • python实现会员管理系统

    2023-11-13 19:44:46
  • Oracle 数据库 临时数据的处理方法

    2009-07-02 11:48:00
  • Python实现文件按照日期命名的方法

    2022-10-25 19:40:09
  • Python中pymysql 模块的使用详解

    2024-01-16 21:07:25
  • QQ在线客服网页代码大全

    2008-01-17 18:28:00
  • Python判断变量名是否合法的方法示例

    2022-07-31 19:05:12
  • javascript获取select的当前值示例代码(兼容IE/Firefox/Opera/Chrome)

    2024-04-22 12:49:59
  • 无级分类的多级联动

    2020-07-02 12:53:12
  • 下雪了 javascript实现雪花飞舞

    2024-05-02 16:16:12
  • python 基于opencv操作摄像头

    2023-03-06 08:02:31
  • 在python中利用GDAL对tif文件进行读写的方法

    2022-03-25 08:08:07
  • asp之家 网络编程 m.aspxhome.com