对TensorFlow中的variables_to_restore函数详解

作者:修炼之路 时间:2022-09-11 00:49:19 

variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。

1、滑动平均值模型文件的保存


import tensorflow as tf

if __name__ == "__main__":
v = tf.Variable(0.,name="v")
#设置滑动平均模型的系数
ema = tf.train.ExponentialMovingAverage(0.99)
#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
op = ema.apply([v])
#获取变量v的名字
print(v.name)
#v:0
#创建一个保存模型的对象
save = tf.train.Saver()
sess = tf.Session()
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#给变量v重新赋值
sess.run(tf.assign(v,10))
#应用平均滑动设置
sess.run(op)
#保存模型文件
save.save(sess,"./model.ckpt")
#输出变量v之前的值和使用滑动平均模型之后的值
print(sess.run([v,ema.average(v)]))
#[10.0, 0.099999905]

上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。

2、滑动平均值模型文件的读取


v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。

3、variables_to_restore函数的使用


v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999

通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。

来源:https://blog.csdn.net/sinat_29957455/article/details/78508793

标签:TensorFlow,variables,to,restore
0
投稿

猜你喜欢

  • mysql服务启动却连接不上的解决方法

    2024-01-24 23:45:13
  • asp中常用的字符串安全处理函数集合(过滤特殊字符等)

    2011-02-20 10:40:00
  • Go语言中错误处理实例分析

    2024-02-14 18:43:10
  • AspJpeg 2.0组件使用教程(GIF篇)

    2008-12-16 19:37:00
  • python ChainMap的使用和说明详解

    2022-03-03 08:22:30
  • Python实现字典按照value进行排序的方法分析

    2022-11-03 04:27:27
  • 驱动程序无法通过使用安全套接字层(SSL)加密与 SQL Server 建立安全连接,错误:“The server selected protocol version TLS10 is not accepted by client

    2024-01-14 21:42:28
  • 关于使用python对mongo多线程更新数据

    2021-08-22 22:07:12
  • Node.js中Bootstrap-table的两种分页的实现方法

    2024-05-11 10:58:21
  • 查看mysql当前连接数的方法详解

    2024-01-21 03:24:59
  • 使用MHTML 解决 data URI scheme 的浏览器兼容问题

    2009-05-11 12:30:00
  • 如何正确合理的建立MYSQL数据库索引

    2010-10-25 20:08:00
  • Go语言LeetCode题解682棒球比赛

    2023-09-17 06:02:59
  • JavaScript中callee和caller的区别与用法实例分析

    2024-04-10 13:59:35
  • Python如何爬取qq音乐歌词到本地

    2021-03-25 19:59:45
  • Python正则表达式re.search()用法详解

    2021-08-28 03:24:46
  • php引用传值实例详解学习

    2023-11-15 06:11:30
  • CodeIgniter自定义控制器MY_Controller用法分析

    2024-05-05 09:17:36
  • python+pytest接口自动化参数关联

    2021-07-06 09:43:55
  • pytest用yaml文件编写测试用例流程详解

    2022-04-19 03:55:29
  • asp之家 网络编程 m.aspxhome.com