TensorFlow梯度求解tf.gradients实例

作者:yqtaowhu 时间:2023-08-16 17:26:03 

我就废话不多说了,直接上代码吧!


import tensorflow as tf

w1 = tf.Variable([[1,2]])
w2 = tf.Variable([[3,4]])

res = tf.matmul(w1, [[2],[1]])

grads = tf.gradients(res,[w1])

with tf.Session() as sess:
tf.global_variables_initializer().run()
print sess.run(res)
print sess.run(grads)

输出结果为:


[[4]]
[array([[2, 1]], dtype=int32)]

可以这样看res与w1有关,w1的参数设为[a1,a2],则:

2*a1 + a2 = res

所以res对a1,a2求导可得 [[2,1]]为w1对应的梯度信息。


import tensorflow as tf
def gradient_clip(gradients, max_gradient_norm):
"""Clipping gradients of a model."""
clipped_gradients, gradient_norm = tf.clip_by_global_norm(
  gradients, max_gradient_norm)
gradient_norm_summary = [tf.summary.scalar("grad_norm", gradient_norm)]
gradient_norm_summary.append(
 tf.summary.scalar("clipped_gradient", tf.global_norm(clipped_gradients)))

return clipped_gradients
w1 = tf.Variable([[3.0,2.0]])
# w2 = tf.Variable([[3,4]])
params = tf.trainable_variables()
res = tf.matmul(w1, [[3.0],[1.]])
opt = tf.train.GradientDescentOptimizer(1.0)
grads = tf.gradients(res,[w1])
clipped_gradients = gradient_clip(grads,2.0)
global_step = tf.Variable(0, name='global_step', trainable=False)
#update = opt.apply_gradients(zip(clipped_gradients,params), global_step=global_step)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print sess.run(res)
print sess.run(grads)
print sess.run(clipped_gradients)

来源:https://blog.csdn.net/taoyanqi8932/article/details/77602721

标签:TensorFlow,梯度,tf.gradients
0
投稿

猜你喜欢

  • python实现凯撒密码、凯撒加解密算法

    2023-08-27 17:49:22
  • 关于scipy.optimize函数使用及说明

    2022-10-19 04:24:04
  • 关于numpy中eye和identity的区别详解

    2021-11-18 14:33:08
  • ASP实现长文章自动分页的函数代码

    2008-10-10 17:09:00
  • Python 实现毫秒级淘宝抢购脚本的示例代码

    2023-05-10 19:50:10
  • Python3转换html到pdf的不同解决方案

    2021-10-03 19:50:03
  • JavaScript中String.prototype用法实例

    2024-04-22 22:18:12
  • 超详细注释之OpenCV制作图像Mask

    2021-10-20 14:15:34
  • python实现在内存中读写str和二进制数据代码

    2022-03-30 04:55:11
  • ASP中如何判断字符串中是否包含字母和数字

    2009-07-10 13:12:00
  • Python3使用requests发闪存的方法

    2021-06-09 16:07:20
  • 利用Python绘制虎年烟花秀

    2022-10-08 06:03:49
  • go语言数组及结构体继承和初始化示例解析

    2024-05-08 10:22:35
  • TCP协议用在python和wifi模块之间详解

    2021-02-04 05:43:08
  • 详解Mysql case then使用

    2024-01-25 05:38:19
  • Python算法输出1-9数组形成的结果为100的所有运算式

    2022-05-02 22:45:48
  • Django+RestFramework API接口及接口文档并返回json数据操作

    2021-05-29 21:43:57
  • Python生命游戏实现原理及过程解析(附源代码)

    2023-07-16 18:02:13
  • Javascript函数类型判断解决方案

    2009-08-27 15:32:00
  • asp如何刪除客户端的Cookies?

    2010-05-18 18:25:00
  • asp之家 网络编程 m.aspxhome.com