TensorFlow的权值更新方法

作者:朂嘼 时间:2022-12-24 21:41:08 

一. MovingAverage权值滑动平均更新

1.1 示例代码:


def create_target_q_network(self,state_dim,action_dim,net):
 state_input = tf.placeholder("float",[None,state_dim])
 action_input = tf.placeholder("float",[None,action_dim])

ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
 target_update = ema.apply(net)
 target_net = [ema.average(x) for x in net]

layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
 layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
 q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])

return state_input,action_input,q_value_output,target_update

def update_target(self):
 self.sess.run(self.target_update)

其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )

第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;

第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。

更新公式为:


shadow_variable = decay * shadow_variable + (1 - decay) * variable

在上述代码段中,target_net是shadow_variable,net是variable

1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)

var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。

函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。

1.3 tf.train.ExponentialMovingAverage.average(var)

用于获取var的滑动平均结果。

二. tf.train.Optimizer更新网络权值

2.1 tf.train.Optimizer

tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。

此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:

1.利用tf.train.Optimizer.compute_gradients计算梯度

2.对梯度进行自定义处理

3.利用tf.train.Optimizer.apply_gradients更新权值


tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)

返回一个(梯度,权值)的列表对。

tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)

返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)

2.2 其它

此外,tensorflow还提供了其它计算梯度的方法:

• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)

该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。

其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。

• tf.stop_gradient(input, name=None)

该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:


cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|

可以事先声明


y=tf.stop_gradient(r+Q next r+Qnext)

来源:https://blog.csdn.net/GH234505/article/details/54976696

标签:TensorFlow,权值,更新
0
投稿

猜你喜欢

  • mysql查找删除重复数据并只保留一条实例详解

    2024-06-05 09:52:53
  • python中pyqtgraph知识点总结

    2022-02-23 10:24:30
  • 一些建站常用简单html代码

    2008-06-01 13:17:00
  • PHP封装cURL工具类与应用示例

    2023-10-18 11:57:36
  • Python实现排序方法常见的四种

    2022-02-18 08:06:15
  • 利用hasOwnProperty给数组去重的面试题分享

    2023-08-06 20:48:37
  • pytorch显存一直变大的解决方案

    2021-03-03 00:03:09
  • MySQL数据库的23个注意事项

    2024-01-23 11:26:06
  • 20个优秀网站助你征服CSS[译]

    2008-09-21 13:21:00
  • 简述php环境搭建与配置

    2023-11-15 09:08:28
  • Python使用Rich type和TinyDB构建联系人通讯录

    2023-07-13 10:33:22
  • 深入了解Python enumerate和zip

    2021-11-15 12:08:23
  • 使用python实现kmean算法

    2022-09-17 13:07:22
  • Flask核心机制之上下文源码剖析

    2022-07-29 18:23:28
  • Python3如何使用tabulate打印数据

    2021-04-17 15:09:26
  • SQL2005CLR函数扩展-数据导出的实现详解

    2024-01-25 11:59:57
  • python 如何使用find和find_all爬虫、找文本的实现

    2023-09-30 02:01:46
  • 桌面中心(二)数据库写入

    2023-11-18 12:26:15
  • MySQL优化配置文件my.ini(discuz论坛)

    2024-01-13 23:34:43
  • python函数的两种嵌套方法使用

    2022-01-14 08:06:58
  • asp之家 网络编程 m.aspxhome.com