Fisher线性判别分析Fisher Linear Distrimination
生活随笔
收集整理的這篇文章主要介紹了
Fisher线性判别分析Fisher Linear Distrimination
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
Fisher線性判別分析是一種線性分類方法,它的主要思想是:是類內的方差小,類均值之間相差比較大。(類間大,類內小)
以兩個類的分類為例:
將兩個類由在x1,x2上投影到向量u 上,這樣由二維轉到了一維,然后將兩類從兩團點的中間分開。
如果要使類間相差大的話,那么每個類的平均數之間也會相差大,設分別為加號點和減號點的平均值,那么投影后,他們距離的平方,也就是盡可能大。
如果要使類內方差小的話,那么兩個類投影到直線(向量)上后,他們的點分別為,表示兩個協方差的投影。
所以他們的和也要盡可能小
因此把大的作為分子,小的作為分母,他們相除的整體就是越大越好,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?設? ? ??
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ??
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?則? ? ??
對J(u)進行求導,則有
令導數等于0,? ? ? 得到? ? ? ? ??
括號里的可以用一個縮放值來代替
則? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
因為和在相乘之后變為常數
因此最終? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?
就是我們要投影的向量。
import matplotlib.pyplot as plt import numpy as npdef gauss2D(x, m, C):Ci = np.linalg.inv(C) #求矩陣的逆dC = np.linalg.det(C) #求矩陣的行列式num = np.exp(-0.5 * np.dot((x-m).T, np.dot(Ci,(x-m))))den = 2 * np.pi * (dC**0.5) #計算矩陣的密度函數return num/dendef twoDGaussianPlot(nx, ny, m, C):x = np.linspace(-6, 6, nx)y = np.linspace(-6, 6, ny)X, Y = np.meshgrid(x, y, indexing='ij')Z = np.zeros([nx,ny])for i in range(nx):for j in range(ny):xvec = np.array([X[i,j], Y[i,j]])Z[i,j] = gauss2D(xvec, m, C)return X, Y, ZX = np.random.randn(200, 2) C1 = np.array([[2,1],[1,2]]) C2 = np.array([[2,1],[1,2]]) m1 = np.array([0, 3]) m2 = np.array([3,2.5]) A = np.linalg.cholesky(C1)Y1 = X @ A.T + m1 Y2 = X @ A.T + m2plt.figure(1) plt.scatter(Y1[:,0], Y1[:,1], c='c', s=4) plt.scatter(Y2[:,0], Y2[:,1], c='m', s=4)Xp, Yp, Zp = twoDGaussianPlot(40,50,m1,C1) plt.contour(Xp, Yp, Zp, 5)Xp2, Yp2, Zp2 = twoDGaussianPlot(40,50,m2,C2) plt.contour(Xp2, Yp2, Zp2, 5)uF = np.linalg.inv(C1 + C2)@(m1-m2) print(uF) #ax.arrow(0, 0, *(uF*10), color='b', linewidth=2.0, head_width=0.20, head_length=0.25) plt.arrow(0, 0, *(uF), color='b', linewidth=2.0, head_width=0.30, head_length=0.35)plt.axis('equal') plt.grid() plt.xlim([-6,6]) plt.ylim([-5,8])plt.savefig('density graph.png')yp1 = Y1 @ uF yp2 = Y2 @ uFplt.figure(2) plt.rcParams.update({'font.size':16}) plt.hist(yp1, bins=40) plt.hist(yp2, bins=40) plt.savefig('histogramprojections.png')總結
以上是生活随笔為你收集整理的Fisher线性判别分析Fisher Linear Distrimination的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 别再说你不会!java嵌入式开发教程
- 下一篇: mysql 逗号_在MySQL字段中使用