Python实现EM算法实例代码

作者:程序员大本营 时间:2021-05-06 03:02:26 

EM算法实例

通过实例可以快速了解EM算法的基本思想,具体推导请点文末链接。图a是让我们预热的,图b是EM算法的实例。

这是一个抛硬币的例子,H表示正面向上,T表示反面向上,参数θ表示正面朝上的概率。硬币有两个,A和B,硬币是有偏的。本次实验总共做了5组,每组随机选一个硬币,连续抛10次。如果知道每次抛的是哪个硬币,那么计算参数θ就非常简单了,如

下图所示:

Python实现EM算法实例代码

如果不知道每次抛的是哪个硬币呢?那么,我们就需要用EM算法,基本步骤为:

  1、给θ_AθA和θ_BθB一个初始值;

  2、(E-step)估计每组实验是硬币A的概率(本组实验是硬币B的概率=1-本组实验是硬币A的概率)。分别计算每组实验中,选择A硬币且正面朝上次数的期望值,选择B硬币且正面朝上次数的期望值;

  3、(M-step)利用第三步求得的期望值重新计算θ_AθA和θ_BθB;

  4、当迭代到一定次数,或者算法收敛到一定精度,结束算法,否则,回到第2步。

Python实现EM算法实例代码

计算过程详解:初始值θ_A^{(0)}θA(0)=0.6,θ_B^{(0)}θB(0)=0.5。

由两个硬币的初始值0.6和0.5,容易得出投掷出5正5反的概率是p_A=C^5_{10}*(0.6^5)*(0.4^5)pA=C105∗(0.65)∗(0.45),p_B=C_{10}^5*(0.5^5)*(0.5^5)pB=C105∗(0.55)∗(0.55), p_ApA/(p_ApA+p_BpB)=0.449, 0.45就是0.449近似而来的,表示第一组实验选择的硬币是A的概率为0.45。然后,0.449 * 5H = 2.2H ,0.449 * 5T = 2.2T ,表示第一组实验选择A硬币且正面朝上次数和反面朝上次数的期望值都是2.2,其他的值依次类推。最后,求出θ_A^{(1)}θA(1)=0.71,θ_B^{(1)}θB(1)=0.58。重复上述过程,不断迭代,直到算法收敛到一定精度为止。

这篇博客对EM算法的推导非常详细,链接如下:

https://blog.csdn.net/zhihua_oba/article/details/73776553

Python实现


#coding=utf-8
from numpy import *
from scipy import stats
import time
start = time.perf_counter()

def em_single(priors,observations):
"""
EM算法的单次迭代
Arguments
------------
priors:[theta_A,theta_B]
observation:[m X n matrix]

Returns
---------------
new_priors:[new_theta_A,new_theta_B]
:param priors:
:param observations:
:return:
"""
counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
theta_A = priors[0]
theta_B = priors[1]
#E step
for observation in observations:
 len_observation = len(observation)
 num_heads = observation.sum()
 num_tails = len_observation-num_heads
 #二项分布求解公式
 contribution_A = stats.binom.pmf(num_heads,len_observation,theta_A)
 contribution_B = stats.binom.pmf(num_heads,len_observation,theta_B)

weight_A = contribution_A / (contribution_A + contribution_B)
 weight_B = contribution_B / (contribution_A + contribution_B)
 #更新在当前参数下A,B硬币产生的正反面次数
 counts['A']['H'] += weight_A * num_heads
 counts['A']['T'] += weight_A * num_tails
 counts['B']['H'] += weight_B * num_heads
 counts['B']['T'] += weight_B * num_tails

# M step
new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
return [new_theta_A,new_theta_B]

def em(observations,prior,tol = 1e-6,iterations=10000):
"""
EM算法
:param observations :观测数据
:param prior:模型初值
:param tol:迭代结束阈值
:param iterations:最大迭代次数
:return:局部最优的模型参数
"""
iteration = 0;
while iteration < iterations:
 new_prior = em_single(prior,observations)
 delta_change = abs(prior[0]-new_prior[0])
 if delta_change < tol:
  break
 else:
  prior = new_prior
  iteration +=1
return [new_prior,iteration]

#硬币投掷结果
observations = array([[1,0,0,0,1,1,0,1,0,1],
     [1,1,1,1,0,1,1,1,0,1],
     [1,0,1,1,1,1,1,0,1,1],
     [1,0,1,0,0,0,1,1,0,0],
     [0,1,1,1,0,1,1,1,0,1]])
print (em(observations,[0.6,0.5]))
end = time.perf_counter()
print('Running time: %f seconds'%(end-start))

来源:https://www.pianshen.com/article/7293314043/

标签:python,em,算法
0
投稿

猜你喜欢

  • sql 常用技巧整理

    2011-11-03 17:10:14
  • asp压缩access数据库方法代码

    2008-08-08 12:22:00
  • 细数JavaScript 一个等号,两个等号,三个等号的区别

    2023-08-25 08:22:09
  • python 获取字典特定值对应的键的实现

    2022-07-01 19:25:21
  • jsp学习之scriptlet的使用方法详解

    2023-06-27 11:06:37
  • 解决IE6、IE7、Firefox兼容最简单的CSS Hack

    2007-10-14 10:51:00
  • Javascript世界的最大整数值

    2008-06-23 13:23:00
  • [多图]新:60个国外创意404页面设计

    2008-12-05 12:00:00
  • Python图像运算之图像掩膜直方图和HS直方图详解

    2023-03-01 03:01:45
  • Python实现读取txt文件中的数据并绘制出图形操作示例

    2021-07-21 17:01:15
  • 人民币的符号的正确表示法?一杠?两杠?¥还是¥呢?

    2010-03-24 12:21:00
  • ASP利用TCPIP.DNS组件实现域名IP查询

    2010-02-26 11:25:00
  • JavaScript几种弹窗事件的使用

    2023-08-24 15:59:08
  • python使用pyaudio录音和格式转化方式

    2023-11-07 19:30:03
  • Web标准在中国

    2008-11-26 11:27:00
  • 如何利用python正确地为图像添加高斯噪声

    2023-08-03 08:26:22
  • Mango Cache缓存管理库TinyLFU源码解析

    2023-09-02 12:27:51
  • Redux saga异步管理与生成器详解

    2023-07-24 02:53:52
  • php计算两个整数的最大公约数常用算法小结

    2023-11-20 00:29:01
  • 如何对Oracle8数据库进行维护?

    2009-11-20 18:01:00
  • asp之家 网络编程 m.aspxhome.com