解决pytorch中的kl divergence计算问题

作者:jingxian 时间:2023-11-12 11:02:00 

偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中计算结果不同,平时没有注意到,记录下

一篇关于KL散度、JS散度以及交叉熵对比的文章

kl divergence 介绍

KL散度( Kullback–Leibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法。计算公式:

解决pytorch中的kl divergence计算问题

可以发现,P 和 Q 中元素的个数不用相等,只需要两个分布中的离散元素一致。

举个简单例子:

两个离散分布分布分别为 P 和 Q

P 的分布为:{1,1,2,2,3}

Q 的分布为:{1,1,1,1,1,2,3,3,3,3}

我们发现,虽然两个分布中元素个数不相同,P 的元素个数为 5,Q 的元素个数为 10。但里面的元素都有 “1”,“2”,“3” 这三个元素。

当 x = 1时,在 P 分布中,“1” 这个元素的个数为 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 这个元素的个数为 5,故 Q(x = 1) = 5/10 = 0.5

同理,

当 x = 2 时,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1

当 x = 3 时,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4

把上述概率带入公式:

解决pytorch中的kl divergence计算问题

至此,就计算完成了两个离散变量分布的KL散度。

pytorch 中的 kl_div 函数

pytorch中有用于计算kl散度的函数 kl_div


torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

解决pytorch中的kl divergence计算问题

计算 D (p||q)

1、不用这个函数的计算结果为:

解决pytorch中的kl divergence计算问题

与手算结果相同

2、使用函数:

(这是计算正确的,结果有差异是因为pytorch这个函数中默认的是以e为底)

解决pytorch中的kl divergence计算问题

注意:

1、函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log

2、reduction 是选择对各部分结果做什么操作,默认为取平均数,这里选择求和

好别扭的用法,不知道为啥官方把它设计成这样

补充:pytorch 的KL divergence的实现

看代码吧~


import torch.nn.functional as F
# p_logit: [batch, class_num]
# q_logit: [batch, class_num]
def kl_categorical(p_logit, q_logit):
   p = F.softmax(p_logit, dim=-1)
   _kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1)
                                 - F.log_softmax(q_logit, dim=-1)), 1)
   return torch.mean(_kl)

来源:https://blog.csdn.net/wwyy2018/article/details/101599862

标签:pytorch,kl,divergence
0
投稿

猜你喜欢

  • oracle数据库添加或删除一列的sql语句

    2012-06-06 19:46:54
  • Django框架 信号调度原理解析

    2022-05-14 20:04:46
  • AjaxUpLoad.js实现文件上传

    2024-05-11 09:42:07
  • PHP 图片上传代码

    2024-05-22 10:05:49
  • Python存储读取HDF5文件代码解析

    2021-07-24 22:33:09
  • 快速实现基于Python的微信聊天机器人示例代码

    2022-05-30 19:22:50
  • vue 自定义右键样式的实例代码

    2023-07-02 16:33:34
  • css样式表滤镜全接触

    2007-10-26 12:48:00
  • 一起感受HTML5和CSS3的能量[译]

    2009-09-04 16:29:00
  • Python Socket TCP双端聊天功能实现过程详解

    2022-03-13 02:25:44
  • 在python中bool函数的取值方法

    2021-10-06 00:47:37
  • 使用canal监控mysql数据库实现elasticsearch索引实时更新问题

    2024-01-20 22:48:39
  • asp如何在ADO中使用存储查询?

    2010-06-17 12:52:00
  • python re正则表达式模块(Regular Expression)

    2021-01-26 20:22:26
  • Python数据可视化图实现过程详解

    2022-08-30 19:24:17
  • 快速掌握如何使用SQL Server来过滤数据

    2009-01-15 13:27:00
  • windows 7安装ORACLE 10g客户端的方法分享

    2012-07-11 15:36:18
  • 利用SQL语言有没有办法查到表中哪些记录中的全部

    2009-04-10 18:29:00
  • Vue插槽slot详细介绍(对比版本变化,避免踩坑)

    2024-05-13 09:13:39
  • Django用数据库表反向生成models类知识点详解

    2024-01-25 15:19:20
  • asp之家 网络编程 m.aspxhome.com