EM Algorithm in Practice
1. Basic EM Algorithm
np.random.multivariate_normal(mean, cov, size) draws samples from a multivariate normal distribution.

To check whether the estimated components line up with the true classes, we need the sample labels and features. The data are male and female heights and weights; the female label is 0, the male label is 1. We have the prior knowledge that men are on average taller than women, so by comparing the two component means we can tell which component corresponds to women and which to men. A flag records this: if the female distribution is the first component, flag = 0.

cmp_point = mpl.colors.ListedColormap(['#22B14C', '#ED1C24']) builds a colormap; when drawing the scatter plot, the c and cmap arguments color the points by class.
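As a quick sketch of np.random.multivariate_normal, here is how two-class height/weight data of this shape could be generated synthetically (the means, covariances, and sample sizes below are made-up illustrative values, not the ones from HeightWeight.csv):

```python
import numpy as np

np.random.seed(0)
# Hypothetical (height, weight) parameters for two groups.
girls = np.random.multivariate_normal(mean=[160, 50], cov=[[25, 10], [10, 16]], size=500)
boys = np.random.multivariate_normal(mean=[175, 65], cov=[[30, 12], [12, 20]], size=500)

X = np.vstack([girls, boys])                   # features, shape (1000, 2)
y = np.hstack([np.zeros(500), np.ones(500)])   # labels: girl=0, boy=1
print(X.shape, y.shape)                        # (1000, 2) (1000,)
```

The cov argument must be a symmetric positive semi-definite matrix; its off-diagonal entries here encode the (assumed) positive correlation between height and weight.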
import numpy as np
from scipy.stats import multivariate_normal
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
from matplotlib import patches as mpatches
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import pairwise_distances_argmin
import pandas as pd

np.random.seed(0)

data = pd.read_csv('../HeightWeight.csv')
print(data.head())
feature = data[['Height(cm)', 'Weight(kg)']]
label = data['Sex']
train_x, test_x, train_y, test_y = train_test_split(feature, label, test_size=0.3)
print(train_x.shape)
print(train_y.shape)

gmm = GaussianMixture(n_components=2, covariance_type='full', max_iter=100)
gmm.fit(train_x)
print('Means:\n', gmm.means_)

mu1, mu2 = gmm.means_
cov1, cov2 = gmm.covariances_
norm1 = multivariate_normal(mu1, cov1)
norm2 = multivariate_normal(mu2, cov2)
tau1 = norm1.pdf(train_x)
tau2 = norm2.pdf(train_x)

# The component with the smaller mean height is assumed to be the female
# (label 0) distribution; flag records which component that is.
flag = 0
if gmm.means_[0][0] < gmm.means_[1][0]:
    c1 = tau1 > tau2
else:
    flag = 1
    c1 = tau1 < tau2
c2 = ~c1

tau1_test = norm1.pdf(test_x)
tau2_test = norm2.pdf(test_x)
if flag:
    c1_test = tau1_test < tau2_test
else:
    c1_test = tau1_test > tau2_test
c2_test = ~c1_test

height_min, height_max = data['Height(cm)'].min(), data['Height(cm)'].max()
weight_min, weight_max = data['Weight(kg)'].min(), data['Weight(kg)'].max()
x = np.linspace(height_min - 0.5, height_max + 0.5, 300)
y = np.linspace(weight_min - 0.5, weight_max + 0.5, 300)
xx, yy = np.meshgrid(x, y)
grid_test = np.stack((xx.flat, yy.flat), axis=1)
grid_predict = gmm.predict(grid_test)

cmp_point = mpl.colors.ListedColormap(['#22B14C', '#ED1C24'])
cmp_bkg = mpl.colors.ListedColormap(['#B0E0E6', '#FFC0CB'])
plt.pcolormesh(xx, yy, grid_predict.reshape(xx.shape), cmap=cmp_bkg)
plt.xlabel('Height(cm)')
plt.ylabel('Weight(kg)')

print(train_x.head())
print(train_x.columns)
print('*' * 20)
print(train_x['Height(cm)'].shape)
print(train_y.shape)
print('*' * 20)

plt.scatter(train_x['Height(cm)'], train_x['Weight(kg)'], c=train_y, marker='o', cmap=cmp_point)
plt.scatter(test_x['Height(cm)'], test_x['Weight(kg)'], c=c2_test, marker='^', s=60, cmap=cmp_point)
patchs = [mpatches.Patch(color='#B0E0E6', label='girl'),
          mpatches.Patch(color='#FFC0CB', label='boy')]
plt.legend(handles=patchs, fancybox=True, framealpha=0.8)
plt.show()

train_acc = np.mean(train_y == c2)
test_acc = np.mean(test_y == c2_test)
print('train acc: ', train_acc)
print('test acc: ', test_acc)
2. GMM Parameters

Covariance type: covariance_type can be one of ('spherical', 'diag', 'tied', 'full').
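A small sketch of what the four covariance types mean in practice: fitting GaussianMixture on toy data with each type and inspecting the shape of covariances_ shows how much structure each one shares across components (the toy data here is made up for illustration):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

np.random.seed(0)
X = np.random.randn(300, 2)  # toy 2-D data, just to exercise the fit

# covariances_ shape depends on covariance_type:
#   'full'      -> (n_components, n_features, n_features)  one full matrix per component
#   'tied'      -> (n_features, n_features)                one matrix shared by all components
#   'diag'      -> (n_components, n_features)              axis-aligned ellipses
#   'spherical' -> (n_components,)                         one variance per component
shapes = {}
for ct in ('spherical', 'diag', 'tied', 'full'):
    gmm = GaussianMixture(n_components=2, covariance_type=ct).fit(X)
    shapes[ct] = np.shape(gmm.covariances_)
    print(ct, shapes[ct])
```

Fewer free covariance parameters ('spherical', 'diag') means a more constrained model that needs less data; 'full' is the most flexible but the easiest to overfit.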
BIC

BIC = k·ln(n) − 2·ln(L)

where k is the number of model parameters, n is the number of samples, and L is the maximized likelihood. The k·ln(n) penalty term helps avoid the curse of dimensionality when the dimension is large and the training data are relatively scarce.
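scikit-learn exposes this criterion directly as GaussianMixture.bic(X), so model selection over the number of components can be sketched as follows (the two-cluster synthetic data is an illustrative assumption, chosen so that BIC should bottom out at k = 2):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

np.random.seed(0)
# Two well-separated clusters, so BIC should favor n_components=2.
X = np.vstack([np.random.multivariate_normal([160, 50], np.eye(2), 200),
               np.random.multivariate_normal([175, 65], np.eye(2), 200)])

bics = []
for k in range(1, 5):
    gmm = GaussianMixture(n_components=k, covariance_type='full', random_state=0).fit(X)
    bics.append(gmm.bic(X))
    print(k, round(bics[-1], 1))

best_k = int(np.argmin(bics)) + 1  # lower BIC is better
print('best k:', best_k)
```

Each extra component adds mean, covariance, and weight parameters, so k·ln(n) grows with n_components; the best model is where the likelihood gain stops paying for that penalty.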
3. DPGMM

DPGMM (the Dirichlet-process Gaussian mixture) is useful for choosing the number of clusters.
from sklearn.mixture import BayesianGaussianMixture

dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full',
                                max_iter=1000, n_init=5,
                                weight_concentration_prior_type='dirichlet_process',
                                weight_concentration_prior=0.1)
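A self-contained sketch of why this helps with choosing the cluster count: fit BayesianGaussianMixture with a deliberately generous n_components on data that really has two clusters, and the Dirichlet-process prior should push the weights of the unneeded components toward zero (the data and the 0.01 weight threshold below are illustrative assumptions):

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

np.random.seed(0)
X = np.vstack([np.random.multivariate_normal([160, 50], np.eye(2), 300),
               np.random.multivariate_normal([175, 65], np.eye(2), 300)])

dpgmm = BayesianGaussianMixture(
    n_components=10,                   # deliberately generous upper bound
    covariance_type='full',
    max_iter=1000,
    weight_concentration_prior_type='dirichlet_process',
    weight_concentration_prior=0.1,    # small prior -> fewer effective clusters
    random_state=0,
).fit(X)

# Components with non-negligible weight are the "effective" clusters.
effective = int(np.sum(dpgmm.weights_ > 0.01))
print('weights:', np.round(dpgmm.weights_, 3))
print('effective clusters:', effective)
```

Unlike the BIC loop above, no explicit model comparison is needed: a single fit with an upper bound on n_components lets the variational inference prune the extra components itself.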