Pytorch中的 torch.distributions库详解

作者:pengjh24 时间:2021-05-17 22:26:47 

Pytorch torch.distributions库

包介绍

torch.distributions包包含可参数化的概率分布和采样函数。 这允许构建用于优化的随机计算图和随机梯度估计器。

不可能通过随机样本直接反向传播。 但是,有两种主要方法可以创建可以反向传播的代理函数。

这些是

评分函数估计量 score function estimato
似然比估计量 likelihood ratio estimator
REINFORCE
路径导数估计量 pathwise derivative estimator
REINFORCE 通常被视为强化学习中策略梯度方法的基础,

路径导数估计器常见于变分自编码器的重新参数化技巧中。

虽然评分函数只需要样本 f(x)的值,但路径导数需要导数 f'(x)。

本文重点讲解Pytorch中的 torch.distributions库。

pytorch 的 torch.distributions 中可以定义正态分布:

import torch
from torch.distributions import  Normal
mean=torch.Tensor([0,2])
normal=Normal(mean,1)

sample()就是直接在定义的正太分布(均值为mean,标准差std是1)上采样:

result = normal.sample()
print("sample():",result)

输出:

sample(): tensor([-1.3362,  3.1730])

rsample()不是在定义的正太分布上采样,而是先对标准正太分布 N(0,1) 进行采样,然后输出: mean + std × 采样值

result = normal.rsample()
print("rsample():",result)

输出:

rsample: tensor([ 0.0530,  2.8396])

log_prob(value) 是计算value在定义的正态分布(mean,1)中对应的概率的对数,正太分布概率密度函数是:

Pytorch中的 torch.distributions库详解

对其取对数可得:

Pytorch中的 torch.distributions库详解

这里我们通过对数概率还原其对应的真实概率:

print("result log_prob:",normal.log_prob(result).exp())

输出:

result log_prob: tensor([ 0.1634,  0.2005])

来源:https://blog.csdn.net/qq_38789531/article/details/104950940

标签:Pytorch,torch,distributions,库
0
投稿

猜你喜欢

  • PHP getName()函数讲解

    2023-06-06 08:28:25
  • jquery的$(document).ready()和onload的加载顺序

    2023-08-23 18:57:40
  • pytorch中forwod函数在父类中的调用方式解读

    2023-04-27 11:12:25
  • kali2021.4a使用virtualenv安装angr的详细过程

    2022-10-15 12:39:00
  • python下的opencv画矩形和文字注释的实现方法

    2022-12-26 22:27:17
  • Go语言使用MySql的方法

    2024-01-20 04:09:25
  • Python模拟登录requests.Session应用详解

    2023-08-04 08:40:38
  • 如何控制弹出一个NTLM验证窗口?

    2009-12-16 19:01:00
  • 奇妙的Javascript图片放大镜

    2024-04-30 08:51:22
  • python Scrapy框架原理解析

    2022-08-07 06:17:20
  • Python OrderedDict字典排序方法详解

    2022-01-07 13:32:09
  • Python设计模式编程中Adapter适配器模式的使用实例

    2023-11-16 10:02:15
  • FCKeditor编辑器实战技巧

    2007-10-08 21:13:00
  • Pandas.DataFrame时间序列数据处理的实现

    2022-09-20 08:43:41
  • echo(),print(),print_r()之间的区别?

    2023-11-15 08:52:42
  • Python字符串拼接的4种方法实例

    2023-01-30 18:57:15
  • PHP常用字符串函数用法实例总结

    2024-05-11 10:01:28
  • Python PIL读取的图像发生自动旋转的实现方法

    2022-05-01 20:29:26
  • golang时间及时间戳的获取转换

    2024-05-05 09:26:27
  • python机器人行走步数问题的解决

    2023-12-24 23:26:05
  • asp之家 网络编程 m.aspxhome.com