Unsupervised learning for hidden Markov models (HMM) in Java (iterative Baum-Welch), with serial and parallel implementations
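I won't go over HMM theory here; this post is about implementing the algorithm.
The implementation is not very hard, provided you have read the HMM theory carefully. Most of the code follows the formulas directly: the forward algorithm, the backward algorithm, and the parameter updates all have explicit formulas that only need to be transcribed into code. Two tricks are worth mentioning.
1. Store every probability as a logarithm. Some probabilities are so small that they would otherwise underflow or lose precision.
2. Taking the logarithm of a sum of probabilities needs a trick of its own. My understanding of that trick is written up below.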
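To make the first point concrete, here is a minimal Java sketch. The class name UnderflowDemo and the numbers are illustrative assumptions, not part of the original code. It multiplies 1000 probabilities of 0.01: the plain product underflows to 0.0 in double precision, while the sum of logarithms stays finite.

```java
public class UnderflowDemo {
    /** Multiplies p by itself n times in probability space. */
    static double product(double p, int n) {
        double result = 1.0;
        for (int i = 0; i < n; i++) {
            result *= p;
        }
        return result;
    }

    /** The same quantity in log space: log(p^n), accumulated by addition. */
    static double logProduct(double p, int n) {
        double result = 0.0;
        for (int i = 0; i < n; i++) {
            result += Math.log(p);
        }
        return result;
    }

    public static void main(String[] args) {
        // 0.01^1000 = 1e-2000 is far below the smallest positive double.
        System.out.println("product     = " + product(0.01, 1000));
        System.out.println("log product = " + logProduct(0.01, 1000));
    }
}
```

Once a product has collapsed to 0.0, every later multiplication stays at 0.0, which is why the listing further down keeps pi, A, B, alpha and beta in log space throughout.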
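The LogSum trick
Suppose we need to compute
$$\log \sum_{i} \alpha_i$$
where each $\alpha_i$ is a probability and only $\log \alpha_i$ is known, not $\alpha_i$ itself. Exponentiating each term directly would underflow, and the logSum trick avoids that. The input is an array whose elements are the values $\log \alpha_i$. Find the largest of them,
$$m = \max_i \log \alpha_i,$$
and then use the identity
$$\log \sum_{i} \alpha_i = m + \log \sum_{i} \exp(\log \alpha_i - m).$$
Sample code: see the logSum method in the listing further down.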
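As a standalone illustration of the identity above, the following sketch compares a naive log-of-sum with the max-shifted version. The class name LogSumDemo and the input values are assumptions made for the example, and the helper assumes a non-empty array.

```java
public class LogSumDemo {
    /** Naive version: exponentiate, sum, take the log. Underflows for very negative inputs. */
    static double naiveLogSum(double[] logProbs) {
        double sum = 0.0;
        for (double v : logProbs) {
            sum += Math.exp(v);
        }
        return Math.log(sum);
    }

    /** Max-shifted version: m + log(sum(exp(v - m))), where m is the largest input. */
    static double logSum(double[] logProbs) {
        double max = logProbs[0];
        for (double v : logProbs) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        for (double v : logProbs) {
            sum += Math.exp(v - max); // the largest term contributes exp(0) = 1
        }
        return max + Math.log(sum);
    }

    public static void main(String[] args) {
        double[] logProbs = {-1000.0, -1001.0, -1002.0};
        System.out.println("naive  = " + naiveLogSum(logProbs)); // exp(-1000) is 0.0 in double
        System.out.println("logSum = " + logSum(logProbs));
    }
}
```

Because the largest term always becomes exp(0) = 1 after the shift, the inner sum is at least 1 and its logarithm is always well defined.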
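This is a Java implementation of an unsupervised HMM, including both the learning algorithm and the prediction algorithm. I have tested it many times and found no errors. However, HMM training is an instance of the EM algorithm, and EM is very sensitive to initial values, so some prior information must be supplied before training: at least one of the HMM parameters pi, A, B has to be given. Otherwise the trained parameters all come out identical and are meaningless. Of course, an unsupervised HMM still performs worse than a supervised one. I tested it on word segmentation, using the parameters of a supervised HMM segmenter as the starting point for the unsupervised HMM, with the following results: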
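The setters in the listing further down (setPriorPi and the others) expect log probabilities. A small conversion helper could look like the sketch below. The class LogPriorDemo and the method toLog are hypothetical names introduced for illustration. The stand-in for log(0) is chosen to match the INFINITY constant of the HMM class, which is -2^31.

```java
import java.util.Arrays;

public class LogPriorDemo {
    /** Stand-in for log(0); assumed to match the INFINITY constant of the HMM class (-2^31). */
    static final double LOG_ZERO = -Math.pow(2, 31);

    /** Converts a probability vector to log space, mapping 0 to LOG_ZERO. */
    static double[] toLog(double[] probs) {
        double[] out = new double[probs.length];
        for (int i = 0; i < probs.length; i++) {
            out[i] = probs[i] == 0.0 ? LOG_ZERO : Math.log(probs[i]);
        }
        return out;
    }

    /** Converts a probability matrix to log space row by row. */
    static double[][] toLog(double[][] probs) {
        double[][] out = new double[probs.length][];
        for (int i = 0; i < probs.length; i++) {
            out[i] = toLog(probs[i]);
        }
        return out;
    }

    public static void main(String[] args) {
        // Made-up prior: a sequence may only start in state 0 or state 3.
        double[] pi = {0.5, 0.0, 0.0, 0.5};
        System.out.println(Arrays.toString(toLog(pi)));
    }
}
```

The result of toLog could then be passed to hmm.setPriorPi(...), and the matrix overload to the setters for A and B.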
Unsupervised HMM results, given parameters pi and B:
Parameters converged....
Final parameters:
pi:[0.0, -2.1474836360090876E9, -2.147483633470365E9, -2.1474836334854264E9]
A:
[-2.1474836482889004E9, -2.2141536347013195, -0.115686913295864, -2.147483648337081E9]
[-2.1474836479972153E9, -1.0239098431251064, -0.4450188807874582, -2.1474836480612097E9]
[-0.7149451174254677, -2.1474836483350754E9, -2.147483648333808E9, -0.6718142750423626]
[-0.4322715044602754, -2.1474836481090927E9, -2.1474836481949987E9, -1.0470634684331648]
[原標題, :, 日, 媒拍, 到, 了, 現場, 罕見, 一幕, 據, 日本, 新聞, 網, (, N, NN, ), 9月, 8日, 報道, ,, 日前, ,, 日本, 海上, 自衛隊, 現役, 最大, 戰艦, 之一, 的, 直升, 機航, 母, “, 加賀, ”, 號在, 南, 海航, 行時, ,, 遭多, 艘, 中國, 海軍, 戰艦, 抵, 近跟, 蹤, 監視, 。]
Supervised HMM:
[原, 標題, :, 日媒, 拍到, 了, 現場, 罕見, 一幕, 據, 日本, 新聞網, (, NN, N)9月8日, 報道, ,, 日前, ,, 日本, 海上, 自衛隊, 現役, 最大, 戰艦, 之, 一, 的, 直升, 機航母, “, 加賀, ”, 號, 在, 南海, 航行, 時, ,, 遭多, 艘, 中國, 海軍, 戰艦, 抵近, 跟蹤, 監視, 。]
Although parameter A was not specified, the learned A is still sensible. All values are log probabilities, and an entry near -2^31 stands for probability 0. With the states ordered B, M, E, S (rows and columns 0 to 3), the probabilities of B→B, B→S, M→B, and M→S are all 0, the same as in the supervised HMM.
So the algorithm does appear to work.
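The learned matrix is easier to read after converting it back to probabilities. The sketch below does that for the A printed above. The class name TransitionMatrixDemo is an assumption, and so is the state order B, M, E, S for indices 0 to 3, inferred from the decode method in SegmentationUtils further down.

```java
import java.util.ArrayList;
import java.util.List;

public class TransitionMatrixDemo {
    /** Assumed state order for indices 0..3. */
    static final String[] STATES = {"B", "M", "E", "S"};

    /** The learned log transition matrix A, copied from the output above. */
    static final double[][] LOG_A = {
        {-2.1474836482889004E9, -2.2141536347013195, -0.115686913295864, -2.147483648337081E9},
        {-2.1474836479972153E9, -1.0239098431251064, -0.4450188807874582, -2.1474836480612097E9},
        {-0.7149451174254677, -2.1474836483350754E9, -2.147483648333808E9, -0.6718142750423626},
        {-0.4322715044602754, -2.1474836481090927E9, -2.1474836481949987E9, -1.0470634684331648}
    };

    /** Lists the transitions whose probability is exactly 0 after exponentiation. */
    static List<String> forbidden(double[][] logA) {
        List<String> result = new ArrayList<>();
        for (int i = 0; i < logA.length; i++) {
            for (int j = 0; j < logA[i].length; j++) {
                if (Math.exp(logA[i][j]) == 0.0) {
                    result.add(STATES[i] + "->" + STATES[j]);
                }
            }
        }
        return result;
    }

    /** Sums one row in probability space; a valid row should sum to about 1. */
    static double rowSum(double[] logRow) {
        double sum = 0.0;
        for (double v : logRow) {
            sum += Math.exp(v);
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println("Zero-probability transitions: " + forbidden(LOG_A));
        for (int i = 0; i < LOG_A.length; i++) {
            System.out.println("Row " + STATES[i] + " sums to " + rowSum(LOG_A[i]));
        }
    }
}
```

Under the assumed state order, the zero-probability transitions are B->B, B->S, M->B, M->S, E->M, E->E, S->M and S->E, which are exactly the transitions a B/M/E/S tagging scheme cannot produce.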
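Judging by the results alone, the unsupervised HMM, even with pi and B specified, still performs worse overall than the supervised HMM. The only test corpus was the segmented People's Daily 1998 corpus.
The original code is fairly long and was written as part of my open-source project, so it is not just a matter of merging everything into one class; it also has some dependencies.
Here I have put together a serial version that needs only two dependencies, for study purposes. Many of the training steps can be parallelized, so I parallelized the time-consuming ones, and that version is much faster than the serial one. I will push it to GitHub in a few days. For the complete source, see my open-source project: GitHub - colin0000007/CONLP, a natural language processing library for beginners that covers word segmentation, part-of-speech tagging, named entity recognition, and dependency parsing, with most models and algorithms implemented from scratch.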
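The parallel version is not shown in this post. As a rough idea of how one step might be parallelized, the sketch below computes ksi for each time step on a parallel IntStream, relying on the fact that the computation for step t writes only to ksi[t]. This is an assumption about one possible approach, not the author's implementation; the class name ParallelKsiDemo and the random inputs are made up for the demonstration.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;

public class ParallelKsiDemo {
    static double logSum(double[] logProbs) {
        double max = logProbs[0];
        for (double v : logProbs) max = Math.max(max, v);
        double sum = 0.0;
        for (double v : logProbs) sum += Math.exp(v - max);
        return max + Math.log(sum);
    }

    /** Fills ksi[t] for one time step; it writes only to ksi[t], so steps are independent. */
    static void ksiStep(int t, int[] x, double[][] logA, double[][] logB,
                        double[][] alpha, double[][] beta, double[][][] ksi) {
        int n = logA.length;
        double[] flat = new double[n * n];
        int k = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                ksi[t][i][j] = alpha[t][i] + logA[i][j] + logB[j][x[t + 1]] + beta[t + 1][j];
                flat[k++] = ksi[t][i][j];
            }
        }
        double norm = logSum(flat); // denominator in log space
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                ksi[t][i][j] -= norm;
            }
        }
    }

    static double[][][] calcKsi(int[] x, double[][] logA, double[][] logB,
                                double[][] alpha, double[][] beta, boolean parallel) {
        int n = logA.length;
        double[][][] ksi = new double[x.length][n][n];
        IntStream steps = IntStream.range(0, x.length - 1);
        if (parallel) {
            steps = steps.parallel();
        }
        steps.forEach(t -> ksiStep(t, x, logA, logB, alpha, beta, ksi));
        return ksi;
    }

    /** Arbitrary log values; they are not normalized and serve only as demo input. */
    static double[][] randomLogs(Random rnd, int rows, int cols) {
        double[][] m = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                m[r][c] = Math.log(rnd.nextDouble() + 1e-3);
            }
        }
        return m;
    }

    /** Runs the serial and the parallel variant on the same input and compares the results. */
    static boolean serialMatchesParallel(long seed) {
        Random rnd = new Random(seed);
        int n = 4, symbols = 6, len = 2000;
        int[] x = new int[len];
        for (int t = 0; t < len; t++) x[t] = rnd.nextInt(symbols);
        double[][] logA = randomLogs(rnd, n, n);
        double[][] logB = randomLogs(rnd, n, symbols);
        double[][] alpha = randomLogs(rnd, len, n);
        double[][] beta = randomLogs(rnd, len, n);
        double[][][] serial = calcKsi(x, logA, logB, alpha, beta, false);
        double[][][] par = calcKsi(x, logA, logB, alpha, beta, true);
        return Arrays.deepEquals(serial, par);
    }

    public static void main(String[] args) {
        System.out.println("serial == parallel: " + serialMatchesParallel(42L));
    }
}
```

The forward and backward recursions cannot be split over time this way, because step t depends on step t-1 or t+1; the per-step work of gamma and ksi and the per-state work of the parameter updates are the more natural candidates.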
The code needs the corpus and the HMM parameters A and B:
語料以及參數A和B.rar (corpus and parameters A and B), shared on Baidu Netdisk
Below is the serial version of the code:
package com.outsider.test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;

/**
 * HMM trained by unsupervised learning.
 * For small amounts of data, serial training is recommended.
 * For large amounts of data (hundreds of thousands or millions of observations, or more),
 * parallel training is strongly recommended; it is more than 4 times faster than serial training.
 * @author outsider
 */
public class UnsupervisedFirstOrderGeneralHMM {
    private double precision = 1e-7;
    /** Length of the training sequence. */
    private int sequenceLen;
    public Logger logger = Logger.getLogger(UnsupervisedFirstOrderGeneralHMM.class.getName());
    /** Initial state probabilities (log space). */
    protected double[] pi;
    /** Transition probabilities (log space). */
    protected double[][] transferProbability1;
    /** Emission probabilities (log space). */
    protected double[][] emissionProbability;
    /** Stand-in for negative infinity, used as log(0). */
    public static final double INFINITY = (double) -Math.pow(2, 31);
    /** Number of distinct states. */
    protected int stateNum;
    /** Number of distinct observation values. */
    protected int observationNum;

    public UnsupervisedFirstOrderGeneralHMM() {
        super();
    }

    public UnsupervisedFirstOrderGeneralHMM(int stateNum, int observationNum, double[] pi,
            double[][] transferProbability1, double[][] emissionProbability) {
        this.stateNum = stateNum;
        this.observationNum = observationNum;
        this.pi = pi;
        this.transferProbability1 = transferProbability1;
        this.emissionProbability = emissionProbability;
    }

    public UnsupervisedFirstOrderGeneralHMM(int stateNum, int observationNum) {
        this.stateNum = stateNum;
        this.observationNum = observationNum;
        initParameters();
    }

    // λ denotes the HMM parameters (pi, A, B) collectively.

    /**
     * Training method.
     * @param x training sequence
     * @param maxIter maximum number of iterations
     * @param precision convergence threshold
     */
    public void train(int[] x, int maxIter, double precision) {
        this.sequenceLen = x.length;
        baumWelch(x, maxIter, precision);
    }

    public void train(int[] x) {
        this.sequenceLen = x.length;
        // No probability normalization is done here.
    }

    /**
     * Iterative Baum-Welch training.
     * During iteration the difference between the new and the previous parameters can grow
     * at first, but later on this error settles to a nearly constant value.
     * The loop therefore stops on any of three conditions:
     * 1. The maximum number of iterations is reached.
     * 2. The largest change in the parameters compared with the previous iteration is below
     *    the given precision.
     * 3. If the precision in condition 2 is too strict, the loop might never terminate, so the
     *    loop also stops when the error of the current iteration differs from the error of the
     *    previous iteration by less than a fixed value (1e-7 here).
     * @param x observation sequence
     * @param maxIter maximum number of iterations; a value <= 0 is replaced by
     *        Integer.MAX_VALUE, i.e. the loop runs until convergence
     * @param precision the parameters are considered converged once their change is below this value
     */
    protected void baumWelch(int[] x, int maxIter, double precision) {
        int iter = 0;
        double oldMaxError = 0;
        if (maxIter <= 0) {
            maxIter = Integer.MAX_VALUE;
        }
        // Allocate the working arrays once.
        double[][] alpha = new double[sequenceLen][stateNum];
        double[][] beta = new double[sequenceLen][stateNum];
        double[][] gamma = new double[sequenceLen][stateNum];
        double[][][] ksi = new double[sequenceLen][stateNum][stateNum];
        while (iter < maxIter) {
            logger.info("\niter" + iter + "...");
            long start = System.currentTimeMillis();
            // E step of EM: compute the quantities needed for the parameter update.
            calcAlpha(x, alpha);
            calcBeta(x, beta);
            calcGamma(x, alpha, beta, gamma);
            calcKsi(x, alpha, beta, ksi);
            // M step of EM: update the parameters.
            double[][] oldA = generateOldA();
            //double[][] oldB = generateOldB();
            //double[] oldPi = pi.clone();
            updateLambda(x, gamma, ksi);
            //double maxError = calcError(oldA, oldPi, oldB);
            double maxError = calcError(oldA, null, null);
            logger.info("max_error:" + maxError);
            if (maxError < precision || (Math.abs(maxError - oldMaxError)) < this.precision) {
                logger.info("Parameters converged....");
                break;
            }
            oldMaxError = maxError;
            iter++;
            long end = System.currentTimeMillis();
            logger.info("Iteration finished, took " + (end - start) + " ms");
        }
        logger.info("Final parameters:");
        logger.info("pi:" + Arrays.toString(pi));
        logger.info("A:");
        for (int i = 0; i < transferProbability1.length; i++) {
            logger.info(Arrays.toString(transferProbability1[i]));
        }
    }

    /**
     * Copies the current parameter A.
     * @return a copy of A
     */
    protected double[][] generateOldA() {
        double[][] oldA = new double[stateNum][stateNum];
        for (int i = 0; i < stateNum; i++) {
            for (int j = 0; j < stateNum; j++) {
                oldA[i][j] = transferProbability1[i][j];
            }
        }
        return oldA;
    }

    /**
     * Copies the current parameter B.
     * @return a copy of B
     */
    protected double[][] generateOldB() {
        double[][] oldB = new double[stateNum][observationNum];
        for (int i = 0; i < stateNum; i++) {
            for (int j = 0; j < observationNum; j++) {
                oldB[i][j] = emissionProbability[i][j];
            }
        }
        return oldB;
    }

    /**
     * For now only the error of parameter A is computed.
     * When B and pi were included, the error kept growing and the loop essentially never converged.
     * @param oldA previous A
     * @param oldPi previous pi (currently unused, may be null)
     * @param oldB previous B (currently unused, may be null)
     * @return the largest absolute change in A
     */
    protected double calcError(double[][] oldA, double[] oldPi, double[][] oldB) {
        double maxError = 0;
        for (int i = 0; i < stateNum; i++) {
            /*double tmp1 = Math.abs(pi[i] - oldPi[i]);
            maxError = tmp1 > maxError ? tmp1 : maxError;*/
            for (int j = 0; j < stateNum; j++) {
                double tmp = Math.abs(oldA[i][j] - transferProbability1[i][j]);
                maxError = tmp > maxError ? tmp : maxError;
            }
            /*for(int k =0; k < observationNum; k++) {
                double tmp2 = Math.abs(emissionProbability[i][k] - oldB[i][k]);
                maxError = tmp2 > maxError ? tmp2 : maxError;
            }*/
        }
        return maxError;
    }

    /**
     * Initializes every probability to 0, i.e. to INFINITY in log space.
     */
    public void initParameters() {
        pi = new double[stateNum];
        transferProbability1 = new double[stateNum][stateNum];
        emissionProbability = new double[stateNum][observationNum];
        for (int i = 0; i < stateNum; i++) {
            pi[i] = INFINITY;
            for (int j = 0; j < stateNum; j++) {
                transferProbability1[i][j] = INFINITY;
            }
            for (int k = 0; k < observationNum; k++) {
                emissionProbability[i][k] = INFINITY;
            }
        }
    }

    /**
     * Sums an array.
     * @param arr input values
     * @return their sum
     */
    public static double sum(double[] arr) {
        double sum = 0;
        for (int i = 0; i < arr.length; i++) {
            sum += arr[i];
        }
        return sum;
    }

    /**
     * Randomly initializes parameter pi.
     */
    public void randomInitPi() {
        for (int i = 0; i < stateNum; i++) {
            pi[i] = Math.random() * 100;
        }
        // Normalize in log space.
        double sum = Math.log(sum(pi));
        for (int i = 0; i < stateNum; i++) {
            if (pi[i] == 0) {
                pi[i] = INFINITY;
                continue;
            }
            pi[i] = Math.log(pi[i]) - sum;
        }
    }

    /**
     * Randomly initializes parameter A.
     */
    public void randomInitA() {
        for (int i = 0; i < stateNum; i++) {
            for (int j = 0; j < stateNum; j++) {
                transferProbability1[i][j] = Math.random() * 100;
            }
            double sum = Math.log(sum(transferProbability1[i]));
            for (int k = 0; k < stateNum; k++) {
                if (transferProbability1[i][k] == 0) {
                    transferProbability1[i][k] = INFINITY;
                    continue;
                }
                transferProbability1[i][k] = Math.log(transferProbability1[i][k]) - sum;
            }
        }
    }

    /**
     * Randomly initializes parameter B.
     */
    public void randomInitB() {
        for (int i = 0; i < stateNum; i++) {
            for (int j = 0; j < observationNum; j++) {
                emissionProbability[i][j] = Math.random() * 100;
            }
            double sum = Math.log(sum(emissionProbability[i]));
            for (int k = 0; k < observationNum; k++) {
                if (emissionProbability[i][k] == 0) {
                    emissionProbability[i][k] = INFINITY;
                    continue;
                }
                emissionProbability[i][k] = Math.log(emissionProbability[i][k]) - sum;
            }
        }
    }

    /**
     * Randomly initializes all parameters.
     */
    public void randomInitAllParameters() {
        randomInitA();
        randomInitB();
        randomInitPi();
    }

    /**
     * Forward algorithm: computes α from the current parameters λ.
     * α is a (sequence length) x (number of states) matrix.
     */
    protected void calcAlpha(int[] x, double[][] alpha) {
        logger.info("Computing alpha...");
        long start = System.currentTimeMillis();
        // Initial values of alpha at t = 0.
        for (int i = 0; i < stateNum; i++) {
            alpha[0][i] = pi[i] + emissionProbability[i][x[0]];
        }
        double[] logProbaArr = new double[stateNum];
        for (int t = 1; t < sequenceLen; t++) {
            for (int i = 0; i < stateNum; i++) {
                for (int j = 0; j < stateNum; j++) {
                    logProbaArr[j] = (alpha[t - 1][j] + transferProbability1[j][i]);
                }
                alpha[t][i] = logSum(logProbaArr) + emissionProbability[i][x[t]];
            }
        }
        long end = System.currentTimeMillis();
        logger.info("Done, took " + (end - start) + " ms");
    }

    /**
     * Backward algorithm: computes β from the current parameters λ.
     * @param x observation sequence
     */
    protected void calcBeta(int[] x, double[][] beta) {
        logger.info("Computing beta...");
        long start = System.currentTimeMillis();
        // Initial value: beta[T-1][i] = 1 in probability space, which is log(1) = 0 in log space.
        // (The original code stored 1 here; the constant offset cancels when gamma and ksi are
        // normalized, but 0 is the correct log-space value.)
        for (int i = 0; i < stateNum; i++) {
            beta[sequenceLen - 1][i] = 0;
        }
        double[] logProbaArr = new double[stateNum];
        for (int t = sequenceLen - 2; t >= 0; t--) {
            for (int i = 0; i < stateNum; i++) {
                for (int j = 0; j < stateNum; j++) {
                    logProbaArr[j] = transferProbability1[i][j] + emissionProbability[j][x[t + 1]] + beta[t + 1][j];
                }
                beta[t][i] = logSum(logProbaArr);
            }
        }
        long end = System.currentTimeMillis();
        logger.info("Done, took " + (end - start) + " ms");
    }

    /**
     * Computes ξ from the current parameters λ.
     * @param x observation sequence
     * @param alpha forward probabilities
     * @param beta backward probabilities
     */
    protected void calcKsi(int[] x, double[][] alpha, double[][] beta, double[][][] ksi) {
        logger.info("Computing ksi...");
        long start = System.currentTimeMillis();
        double[] logProbaArr = new double[stateNum * stateNum];
        for (int t = 0; t < sequenceLen - 1; t++) {
            int k = 0;
            for (int i = 0; i < stateNum; i++) {
                for (int j = 0; j < stateNum; j++) {
                    ksi[t][i][j] = alpha[t][i] + transferProbability1[i][j] + emissionProbability[j][x[t + 1]] + beta[t + 1][j];
                    logProbaArr[k++] = ksi[t][i][j];
                }
            }
            double logSum = logSum(logProbaArr); // denominator
            for (int i = 0; i < stateNum; i++) {
                for (int j = 0; j < stateNum; j++) {
                    ksi[t][i][j] -= logSum; // numerator divided by denominator
                }
            }
        }
        long end = System.currentTimeMillis();
        logger.info("Done, took " + (end - start) + " ms");
    }

    /**
     * Computes γ from the current parameters λ.
     * @param x observation sequence
     */
    protected void calcGamma(int[] x, double[][] alpha, double[][] beta, double[][] gamma) {
        logger.info("Computing gamma...");
        long start = System.currentTimeMillis();
        for (int t = 0; t < sequenceLen; t++) {
            // The denominator needs a logSum.
            for (int i = 0; i < stateNum; i++) {
                gamma[t][i] = alpha[t][i] + beta[t][i];
            }
            double logSum = logSum(gamma[t]); // denominator
            for (int j = 0; j < stateNum; j++) {
                gamma[t][j] = gamma[t][j] - logSum;
            }
        }
        long end = System.currentTimeMillis();
        logger.info("Done, took " + (end - start) + " ms");
    }

    /**
     * Updates the parameters.
     */
    protected void updateLambda(int[] x, double[][] gamma, double[][][] ksi) {
        // The order of these calls does not matter.
        updatePi(gamma);
        updateA(ksi, gamma);
        updateB(x, gamma);
    }

    /**
     * Updates parameter pi.
     * @param gamma state posteriors
     */
    public void updatePi(double[][] gamma) {
        for (int i = 0; i < stateNum; i++) {
            pi[i] = gamma[0][i];
        }
    }

    /**
     * Updates parameter A.
     * @param ksi transition posteriors
     * @param gamma state posteriors
     */
    protected void updateA(double[][][] ksi, double[][] gamma) {
        logger.info("Updating transition probabilities A...");
        // Every update of A needs, for each state, the sum of gamma over the first T-1 time
        // steps, so compute those sums first.
        double[] gammaSum = new double[stateNum];
        double[] tmp = new double[sequenceLen - 1];
        for (int i = 0; i < stateNum; i++) {
            for (int t = 0; t < sequenceLen - 1; t++) {
                tmp[t] = gamma[t][i];
            }
            gammaSum[i] = logSum(tmp);
        }
        long start1 = System.currentTimeMillis();
        double[] ksiLogProbArr = new double[sequenceLen - 1];
        for (int i = 0; i < stateNum; i++) {
            for (int j = 0; j < stateNum; j++) {
                for (int t = 0; t < sequenceLen - 1; t++) {
                    ksiLogProbArr[t] = ksi[t][i][j];
                }
                transferProbability1[i][j] = logSum(ksiLogProbArr) - gammaSum[i];
            }
        }
        long end1 = System.currentTimeMillis();
        logger.info("Update finished, took " + (end1 - start1) + " ms");
    }

    /**
     * Updates parameter B.
     * @param x observation sequence
     * @param gamma state posteriors
     */
    protected void updateB(int[] x, double[][] gamma) {
        // Every update of B needs, for each state, the sum of gamma over all time steps,
        // so compute those sums first to avoid repeated work.
        double[] gammaSum2 = new double[stateNum];
        double[] tmp2 = new double[sequenceLen];
        for (int i = 0; i < stateNum; i++) {
            for (int t = 0; t < sequenceLen; t++) {
                tmp2[t] = gamma[t][i];
            }
            gammaSum2[i] = logSum(tmp2);
        }
        logger.info("Updating emission probabilities B...");
        long start2 = System.currentTimeMillis();
        ArrayList<Double> valid = new ArrayList<Double>();
        for (int i = 0; i < stateNum; i++) {
            for (int k = 0; k < observationNum; k++) {
                valid.clear(); // forgetting to reset this list once caused wrong results
                for (int t = 0; t < sequenceLen; t++) {
                    if (x[t] == k) {
                        valid.add(gamma[t][i]);
                    }
                }
                // B[i][k]: observation k has probability 0 in state i.
                if (valid.size() == 0) {
                    emissionProbability[i][k] = INFINITY;
                    continue;
                }
                // logSum of the numerator.
                double[] validArr = new double[valid.size()];
                for (int q = 0; q < valid.size(); q++) {
                    validArr[q] = valid.get(q);
                }
                double validSum = logSum(validArr);
                // The logSum of the denominator was computed above.
                emissionProbability[i][k] = validSum - gammaSum2[i];
            }
        }
        long end2 = System.currentTimeMillis();
        logger.info("Update finished, took " + (end2 - start2) + " ms");
    }

    /**
     * The logSum trick.
     * @param logProbaArr log probabilities
     * @return the log of the sum of the corresponding probabilities
     */
    public double logSum(double[] logProbaArr) {
        if (logProbaArr.length == 0) {
            return INFINITY;
        }
        double max = max(logProbaArr);
        double result = 0;
        for (int i = 0; i < logProbaArr.length; i++) {
            result += Math.exp(logProbaArr[i] - max);
        }
        return max + Math.log(result);
    }

    /**
     * Sets the prior pi.
     * The values must be log probabilities.
     * @param pi log initial state probabilities
     */
    public void setPriorPi(double[] pi) {
        this.pi = pi;
    }

    /**
     * Sets the prior transition probabilities A.
     * The values must be log probabilities.
     * @param trtransferProbability1 log transition probabilities
     */
    public void setPriorTransferProbability1(double[][] trtransferProbability1) {
        this.transferProbability1 = trtransferProbability1;
    }

    /**
     * Sets the prior emission probabilities B.
     * The values must be log probabilities.
     * @param emissionProbability log emission probabilities
     */
    public void setPriorEmissionProbability(double[][] emissionProbability) {
        this.emissionProbability = emissionProbability;
    }

    public static double max(double[] arr) {
        double max = arr[0];
        for (int i = 1; i < arr.length; i++) {
            max = arr[i] > max ? arr[i] : max;
        }
        return max;
    }

    /**
     * Viterbi decoding.
     * @param O observation sequence; it must already be encoded, not raw data. For example,
     *        if the sequence is a string, the input must be the character codes, not the
     *        characters themselves.
     * @return the predicted state sequence
     */
    public int[] verterbi(int[] O) {
        double[][] deltas = new double[O.length][this.stateNum];
        // states[t][i] records which previous state produced deltas[t][i].
        int[][] states = new int[O.length][this.stateNum];
        // Initialize deltas[0][].
        for (int i = 0; i < this.stateNum; i++) {
            deltas[0][i] = pi[i] + emissionProbability[i][O[0]];
        }
        // Compute deltas.
        for (int t = 1; t < O.length; t++) {
            for (int i = 0; i < this.stateNum; i++) {
                deltas[t][i] = deltas[t - 1][0] + transferProbability1[0][i];
                for (int j = 1; j < this.stateNum; j++) {
                    double tmp = deltas[t - 1][j] + transferProbability1[j][i];
                    if (tmp > deltas[t][i]) {
                        deltas[t][i] = tmp;
                        states[t][i] = j;
                    }
                }
                deltas[t][i] += emissionProbability[i][O[t]];
            }
        }
        // Backtrack to find the best path.
        int[] predict = new int[O.length];
        double max = deltas[O.length - 1][0];
        for (int i = 1; i < this.stateNum; i++) {
            if (deltas[O.length - 1][i] > max) {
                max = deltas[O.length - 1][i];
                predict[O.length - 1] = i;
            }
        }
        for (int i = O.length - 2; i >= 0; i--) {
            predict[i] = states[i + 1][predict[i + 1]];
        }
        return predict;
    }

    // Test
    public static void main(String[] args) {
        UnsupervisedFirstOrderGeneralHMM hmm = new UnsupervisedFirstOrderGeneralHMM(4, 65536);
        // Turn off log output:
        //CONLPLogger.closeLogger(hmm.logger);
        // The corpus was prepared for supervised learning, so the separators have to be removed.
        String path = "src/pku_training.splitBy2space.utf8";
        String data = IOUtils.readText(path, "utf-8");
        String[] d2 = data.split(" ");
        StringBuilder sb = new StringBuilder();
        for (String word : d2) {
            sb.append(word);
        }
        data = sb.toString();
        // Training data.
        int[] x = SegmentationUtils.str2int(data);
        // Serial training is slow, so only the first 10000 observations are used.
        int[] minX = new int[10000];
        System.arraycopy(x, 0, minX, 0, 10000);
        // Priors must be set before training because EM is sensitive to initial values.
        // Without them every probability defaults to 0 and all parameters end up identical,
        // which is meaningless.
        // If priors are available for only some parameters, the others can be initialized
        // randomly, for example:
        //hmm.randomInitA();
        //hmm.randomInitB();
        //hmm.randomInitPi();
        //hmm.randomInitAllParameters();
        // Set the prior information: at least one of pi, A, B.
        hmm.setPriorPi(new double[] {-1.138130826175848, -2.632826946498266, -1.138130826175848, -1.2472622308278396});
        hmm.setPriorTransferProbability1((double[][]) IOUtils.readObject("src/A"));
        hmm.setPriorEmissionProbability((double[][]) IOUtils.readObject("src/B"));
        // Start training.
        hmm.train(minX, -1, 0.5);
        String str = "原標題:日媒拍到了現場罕見一幕"
                + "據日本新聞網(NNN)9月8日報道,日前,日本海上自衛隊現役最大戰艦之一的直升機航母“加賀”號在南海航行時,遭多艘中國海軍戰艦抵近跟蹤監視。";
        // Convert the characters to their Unicode code units.
        int[] O = SegmentationUtils.str2int(str);
        int[] predict = hmm.verterbi(O);
        System.out.println(Arrays.toString(predict));
        String[] res = SegmentationUtils.decode(predict, str);
        System.out.println(Arrays.toString(res));
    }
}

Dependency IOUtils:
package com.outsider.test;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;

public class IOUtils {
    public static String readTextWithLineCheckBreak(String path, String encoding) {
        return readText(path, encoding, "\n");
    }

    /**
     * Reads a text file and returns it as one string, without line breaks.
     * @param path file path
     * @param encoding encoding; null or an empty string selects the default encoding
     * @return the file content
     */
    public static String readText(String path, String encoding) {
        return readText(path, encoding, null);
    }

    /**
     * Reads a text file, appending the given string to the end of every line.
     * @param path file path
     * @param encoding encoding; null or an empty string selects the default encoding
     * @param lineEndStr string appended to each line; null means nothing is appended
     * @return the file content, or null if reading failed
     */
    public static String readText(String path, String encoding, String lineEndStr) {
        try {
            if (lineEndStr == null) {
                lineEndStr = "";
            }
            BufferedReader reader = null;
            // Check for null first; the original order threw a NullPointerException on null.
            if (encoding != null && !encoding.trim().equals("")) {
                reader = new BufferedReader(new InputStreamReader(new FileInputStream(path), encoding));
            } else {
                reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
            }
            String s = "";
            StringBuilder sb = new StringBuilder();
            while ((s = reader.readLine()) != null) {
                sb.append(s + lineEndStr);
            }
            reader.close();
            return sb.toString();
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Reads a text file and returns its lines as a list.
     * @param path file path
     * @param encoding encoding; null or an empty string selects the default encoding
     * @param addNewLine whether to append a line break to each line
     * @return the lines, or null if reading failed
     */
    public static List<String> readTextAndReturnLinesCheckLineBreak(String path, String encoding, boolean addNewLine) {
        try {
            String lineBreak;
            if (addNewLine) {
                lineBreak = "\n";
            } else {
                lineBreak = "";
            }
            BufferedReader reader = null;
            if (encoding != null && !encoding.trim().equals("")) {
                reader = new BufferedReader(new InputStreamReader(new FileInputStream(path), encoding));
            } else {
                reader = new BufferedReader(new InputStreamReader(new FileInputStream(path)));
            }
            String s = "";
            List<String> list = new ArrayList<>();
            while ((s = reader.readLine()) != null) {
                list.add(s + lineBreak);
            }
            reader.close();
            return list;
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    public static List<String> readTextAndReturnLines(String path, String encoding) {
        return readTextAndReturnLinesCheckLineBreak(path, encoding, false);
    }

    /**
     * Reads every line of a text file and returns the lines as an array.
     * @param path file path
     * @param encoding encoding
     * @return the lines
     */
    public static String[] readTextAndReturnLinesOfArray(String path, String encoding) {
        List<String> lines = readTextAndReturnLines(path, encoding);
        String[] arr = new String[lines.size()];
        lines.toArray(arr);
        return arr;
    }

    /**
     * Writes text to a file.
     * @param data text to write
     * @param path file path
     * @param encoding encoding; null or an empty string selects the default encoding
     */
    public static void writeTextData2File(String data, String path, String encoding) {
        try {
            BufferedWriter writer = null;
            if (encoding != null && !encoding.trim().equals("")) {
                writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path), encoding));
            } else {
                writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path)));
            }
            writer.write(data);
            writer.close();
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Writes an object to a file.
     * @param path file path
     * @param object object to serialize
     */
    public static void writeObject2File(String path, Object object) {
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(path))) {
            out.writeObject(object);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads an object from a file.
     * @param path file path
     * @return the deserialized object, or null if reading failed
     */
    public static Object readObject(String path) {
        // try-with-resources closes the stream, which the original code left open.
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(path))) {
            return in.readObject();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
}

Dependency SegmentationUtils:
package com.outsider.test;

import java.util.ArrayList;
import java.util.List;

public class SegmentationUtils {
    /**
     * Converts every character of every string in the array to its Unicode code unit.
     * @param strs array of strings
     * @return the Unicode values, one int array per string
     */
    public static List<int[]> strs2int(String[] strs) {
        List<int[]> res = new ArrayList<>(strs.length);
        for (int i = 0; i < strs.length; i++) {
            int[] O = new int[strs[i].length()];
            for (int j = 0; j < strs[i].length(); j++) {
                O[j] = strs[i].charAt(j);
            }
            res.add(O);
        }
        return res;
    }

    public static int[] str2int(String str) {
        return strs2int(new String[] {str}).get(0);
    }

    /**
     * Decodes a prediction into words.
     * State codes as used by the loop below: B=0, M=1, E=2, S=3. A word starts at a B or M
     * and runs up to the next E; any other state yields a single-character word.
     * @param predict predicted states
     * @param sentence the sentence
     * @return the segmented words
     */
    public static String[] decode(int[] predict, String sentence) {
        List<String> res = new ArrayList<>();
        char[] chars = sentence.toCharArray();
        for (int i = 0; i < predict.length; i++) {
            if (predict[i] == 0 || predict[i] == 1) {
                int a = i;
                while (predict[i] != 2) {
                    i++;
                    if (i == predict.length) {
                        break;
                    }
                }
                int b = i;
                if (b == predict.length) {
                    b--;
                }
                res.add(new String(chars, a, b - a + 1));
            } else {
                res.add(new String(chars, i, 1));
            }
        }
        String[] s = new String[res.size()];
        return res.toArray(s);
    }
}
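A quick way to gain confidence in a log-space recursion such as calcAlpha is to compare it with the plain probability-space forward algorithm on a toy model where nothing underflows. The sketch below does this for a made-up 2-state, 2-symbol HMM. The class name ForwardCheckDemo and all numbers are assumptions for illustration, not values from the post.

```java
public class ForwardCheckDemo {
    static double logSum(double[] logProbs) {
        double max = logProbs[0];
        for (double v : logProbs) max = Math.max(max, v);
        double sum = 0.0;
        for (double v : logProbs) sum += Math.exp(v - max);
        return max + Math.log(sum);
    }

    /** Forward algorithm in probability space; returns P(x). */
    static double forwardProb(double[] pi, double[][] a, double[][] b, int[] x) {
        int n = pi.length;
        double[] alpha = new double[n];
        for (int i = 0; i < n; i++) {
            alpha[i] = pi[i] * b[i][x[0]];
        }
        for (int t = 1; t < x.length; t++) {
            double[] next = new double[n];
            for (int i = 0; i < n; i++) {
                double s = 0.0;
                for (int j = 0; j < n; j++) {
                    s += alpha[j] * a[j][i];
                }
                next[i] = s * b[i][x[t]];
            }
            alpha = next;
        }
        double total = 0.0;
        for (double v : alpha) total += v;
        return total;
    }

    /** Forward algorithm in log space; returns log P(x). Assumes no zero probabilities. */
    static double forwardLog(double[] pi, double[][] a, double[][] b, int[] x) {
        int n = pi.length;
        double[] alpha = new double[n];
        for (int i = 0; i < n; i++) {
            alpha[i] = Math.log(pi[i]) + Math.log(b[i][x[0]]);
        }
        double[] tmp = new double[n];
        for (int t = 1; t < x.length; t++) {
            double[] next = new double[n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    tmp[j] = alpha[j] + Math.log(a[j][i]);
                }
                next[i] = logSum(tmp) + Math.log(b[i][x[t]]);
            }
            alpha = next;
        }
        return logSum(alpha);
    }

    public static void main(String[] args) {
        double[] pi = {0.6, 0.4};
        double[][] a = {{0.7, 0.3}, {0.4, 0.6}};
        double[][] b = {{0.9, 0.1}, {0.2, 0.8}};
        int[] x = {0, 1, 1, 0, 1};
        System.out.println("log of probability-space result: " + Math.log(forwardProb(pi, a, b, x)));
        System.out.println("log-space result:                " + forwardLog(pi, a, b, x));
    }
}
```

On short sequences the two values agree up to rounding. On long sequences only the log-space version stays finite, which is the reason the listing works in log space.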