tensorflow创建变量以及根据名称查找变量

作者:lijiao 时间:2023-08-13 10:13:06 

环境:Ubuntu14.04,tensorflow=1.4(bazel源码安装),Anaconda python=3.6

声明变量主要有两种方法:tf.Variabletf.get_variable,二者的最大区别是:

(1) tf.Variable是一个类,自带很多属性函数;而 tf.get_variable是一个函数;
(2) tf.Variable只能生成独一无二的变量,即如果给出的name已经存在,则会自动修改生成新的变量name;
(3) tf.get_variable可以用于生成共享变量。默认情况下,该函数会进行变量名检查,如果有重复则会报错。当在指定变量域中声明可

以变量共享时,可以重复使用该变量(例如RNN中的参数共享)。
下面给出简单的的示例程序:


import tensorflow as tf

with tf.variable_scope('scope1',reuse=tf.AUTO_REUSE) as scope1:
 x1 = tf.Variable(tf.ones([1]),name='x1')
 x2 = tf.Variable(tf.zeros([1]),name='x1')
 y1 = tf.get_variable('y1',initializer=1.0)
 y2 = tf.get_variable('y1',initializer=0.0)
 init = tf.global_variables_initializer()
 with tf.Session() as sess:
   sess.run(init)
   print(x1.name,x1.eval())
   print(x2.name,x2.eval())
   print(y1.name,y1.eval())
   print(y2.name,y2.eval())

输出结果为:


scope1/x1:0 [ 1.]
scope1/x1_1:0 [ 0.]
scope1/y1:0 1.0
scope1/y1:0 1.0

1. tf.Variable(…)

tf.Variable(…)使用给定初始值来创建一个新变量,该变量会默认添加到 graph collections listed in collections, which defaults to [GraphKeys.GLOBAL_VARIABLES]。

如果trainable属性被设置为True,该变量同时也会被添加到graph collection GraphKeys.TRAINABLE_VARIABLES.


# tf.Variable
__init__(
 initial_value=None,
 trainable=True,
 collections=None,
 validate_shape=True,
 caching_device=None,
 name=None,
 variable_def=None,
 dtype=None,
 expected_shape=None,
 import_scope=None,
 constraint=None
)

2. tf.get_variable(…)

tf.get_variable(…)的返回值有两种情形:

使用指定的initializer来创建一个新变量;
当变量重用时,根据变量名搜索返回一个由tf.get_variable创建的已经存在的变量;


get_variable(
 name,
 shape=None,
 dtype=None,
 initializer=None,
 regularizer=None,
 trainable=True,
 collections=None,
 caching_device=None,
 partitioner=None,
 validate_shape=True,
 use_resource=None,
 custom_getter=None,
 constraint=None
)

3. 根据名称查找变量

在创建变量时,即使我们不指定变量名称,程序也会自动进行命名。于是,我们可以很方便的根据名称来查找变量,这在抓取参数、finetune模型等很多时候都很有用。

示例1:

通过在tf.global_variables()变量列表中,根据变量名进行匹配搜索查找。 该种搜索方式,可以同时找到由tf.Variable或者tf.get_variable创建的变量。


import tensorflow as tf

x = tf.Variable(1,name='x')
y = tf.get_variable(name='y',shape=[1,2])
for var in tf.global_variables():
 if var.name == 'x:0':
   print(var)

示例2:

利用get_tensor_by_name()同样可以获得由tf.Variable或者tf.get_variable创建的变量。
需要注意的是,此时获得的是Tensor, 而不是Variable,因此 x不等于x1.


import tensorflow as tf

x = tf.Variable(1,name='x')
y = tf.get_variable(name='y',shape=[1,2])

graph = tf.get_default_graph()

x1 = graph.get_tensor_by_name("x:0")
y1 = graph.get_tensor_by_name("y:0")

示例3:

针对tf.get_variable创建的变量,可以利用变量重用来直接获取已经存在的变量。


with tf.variable_scope("foo"):
 bar1 = tf.get_variable("bar", (2,3)) # create

with tf.variable_scope("foo", reuse=True):
 bar2 = tf.get_variable("bar") # reuse

with tf.variable_scope("", reuse=True): # root variable scope
 bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)

print((bar1 is bar2) and (bar2 is bar3))
标签:tensorflow,变量
0
投稿

猜你喜欢

  • Pandas缺失值填充 df.fillna()的实现

    2023-11-24 00:01:41
  • python办公之python编辑word

    2022-03-31 08:48:40
  • 浅谈pycharm使用及设置方法

    2023-12-18 21:17:47
  • 初探TensorFLow从文件读取图片的四种方式

    2021-08-06 06:04:34
  • SQL点滴24 监测表的变化

    2011-09-30 11:38:41
  • 对python mayavi三维绘图的实现详解

    2022-04-29 03:17:50
  • PHP实时统计中文字数和区别

    2023-07-13 10:44:01
  • FSO读取BMP,JPG,PNG,GIF图像文件信息的函数

    2007-08-04 09:56:00
  • python实现在线翻译功能

    2023-06-02 22:12:38
  • python用Pygal如何生成漂亮的SVG图像详解

    2022-12-12 21:45:22
  • pyqt5-tools安装失败的详细处理方法

    2021-08-22 21:34:00
  • selenium python 实现基本自动化测试的示例代码

    2021-05-04 06:23:07
  • 用VBS语言实现的网页计算器源代码

    2007-12-26 17:09:00
  • 学python需要去培训机构吗

    2022-02-12 07:46:29
  • 没编程基础可以学python吗

    2023-11-27 23:12:49
  • 解决python打不开文件(文件不存在)的问题

    2021-10-15 02:39:46
  • 一些关于SQL2005+ASP.NET2.0的问题

    2007-09-23 13:01:00
  • CentOS中升级Python版本的方法详解

    2021-08-22 20:22:30
  • 浅析Python中线程以及线程阻塞

    2022-03-06 22:14:23
  • 计划备份mysql数据库

    2009-03-09 14:34:00
  • asp之家 网络编程 m.aspxhome.com