Reposted from:
https://blog.csdn.net/Leafage_M/article/details/80137305
To sum the blog post up in one sentence: given the n values of a continuous feature, sort them and average each adjacent pair, giving n-1 candidate split values. The key points:
① Computing the entropy for a continuous feature means trying each of those n-1 split values in turn: for each candidate, the data on the two sides of the split value are used to compute the conditional entropy, and the split value with the maximum information gain wins.
② If both discrete and continuous features are present, take the continuous feature's maximum information gain and compare it against the gains of the discrete features, then choose the feature with the best gain as the split.
③ If only continuous features remain, every split simply picks the split value that maximizes the information gain.
This confirms a remark in Zhou Zhihua's "Machine Learning": in a decision tree, a discrete feature can be used for a split only once, while a continuous feature can be used multiple times.
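Before the full program, here is a minimal, self-contained sketch of that midpoint search (the helper names entropy and best_split and the standalone structure are mine, not from the original post; the logic mirrors calcInfoGainForSeries in the code below):

from math import log


def entropy(class_labels):
    # Shannon entropy of a list of class labels
    counts = {}
    for y in class_labels:
        counts[y] = counts.get(y, 0) + 1
    n = float(len(class_labels))
    return -sum(c / n * log(c / n, 2) for c in counts.values())


def best_split(values, class_labels):
    # Try the n-1 midpoints of adjacent sorted values and return the
    # (gain, split value) pair with the largest information gain
    pairs = sorted(zip(values, class_labels))
    base = entropy(class_labels)
    n = len(pairs)
    mids = [(pairs[k][0] + pairs[k + 1][0]) / 2.0 for k in range(n - 1)]
    best_gain, best_mid = 0.0, None
    for mid in mids:
        left = [y for x, y in pairs if x <= mid]
        right = [y for x, y in pairs if x > mid]
        cond = (len(left) * entropy(left) + len(right) * entropy(right)) / float(n)
        gain = base - cond
        if gain > best_gain:
            best_gain, best_mid = gain, mid
    return best_gain, best_mid


# Density column of watermelon dataset 3.0: the first 8 are good melons
density = [0.697, 0.774, 0.634, 0.608, 0.556, 0.403, 0.481, 0.437,
           0.666, 0.243, 0.245, 0.343, 0.639, 0.657, 0.360, 0.593, 0.719]
print(best_split(density, ['好瓜'] * 8 + ['壞瓜'] * 9))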
The code at the link above is Python 3; adapted to Python 2.7 it looks like this:
top.py
#-*- coding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import collections
from math import log
import operator
import treePlotter
import pandas as pd


def createDataSet():
    """
    Watermelon dataset 3.0 (from Zhou Zhihua's "Machine Learning").
    :return:
    """
    dataSet = [
        # 1
        ['青綠', '蜷縮', '濁響', '清晰', '凹陷', '硬滑', 0.697, 0.460, '好瓜'],
        # 2
        ['烏黑', '蜷縮', '沉悶', '清晰', '凹陷', '硬滑', 0.774, 0.376, '好瓜'],
        # 3
        ['烏黑', '蜷縮', '濁響', '清晰', '凹陷', '硬滑', 0.634, 0.264, '好瓜'],
        # 4
        ['青綠', '蜷縮', '沉悶', '清晰', '凹陷', '硬滑', 0.608, 0.318, '好瓜'],
        # 5
        ['淺白', '蜷縮', '濁響', '清晰', '凹陷', '硬滑', 0.556, 0.215, '好瓜'],
        # 6
        ['青綠', '稍蜷', '濁響', '清晰', '稍凹', '軟粘', 0.403, 0.237, '好瓜'],
        # 7
        ['烏黑', '稍蜷', '濁響', '稍糊', '稍凹', '軟粘', 0.481, 0.149, '好瓜'],
        # 8
        ['烏黑', '稍蜷', '濁響', '清晰', '稍凹', '硬滑', 0.437, 0.211, '好瓜'],
        # ----------------------------------------------------
        # 9
        ['烏黑', '稍蜷', '沉悶', '稍糊', '稍凹', '硬滑', 0.666, 0.091, '壞瓜'],
        # 10
        ['青綠', '硬挺', '清脆', '清晰', '平坦', '軟粘', 0.243, 0.267, '壞瓜'],
        # 11
        ['淺白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 0.245, 0.057, '壞瓜'],
        # 12
        ['淺白', '蜷縮', '濁響', '模糊', '平坦', '軟粘', 0.343, 0.099, '壞瓜'],
        # 13
        ['青綠', '稍蜷', '濁響', '稍糊', '凹陷', '硬滑', 0.639, 0.161, '壞瓜'],
        # 14
        ['淺白', '稍蜷', '沉悶', '稍糊', '凹陷', '硬滑', 0.657, 0.198, '壞瓜'],
        # 15
        ['烏黑', '稍蜷', '濁響', '清晰', '稍凹', '軟粘', 0.360, 0.370, '壞瓜'],
        # 16
        ['淺白', '蜷縮', '濁響', '模糊', '平坦', '硬滑', 0.593, 0.042, '壞瓜'],
        # 17
        ['青綠', '蜷縮', '沉悶', '稍糊', '稍凹', '硬滑', 0.719, 0.103, '壞瓜']
    ]

    # Watermelon dataset 3.0a (density and sugar content only); uncomment
    # this block and the matching `labels` below to reproduce Figure 4.10:
    # dataSet = [
    #     [0.697, 0.460, '好瓜'], [0.774, 0.376, '好瓜'], [0.634, 0.264, '好瓜'],
    #     [0.608, 0.318, '好瓜'], [0.556, 0.215, '好瓜'], [0.403, 0.237, '好瓜'],
    #     [0.481, 0.149, '好瓜'], [0.437, 0.211, '好瓜'],
    #     # ----------------------------------------------------
    #     [0.666, 0.091, '壞瓜'], [0.243, 0.267, '壞瓜'], [0.245, 0.057, '壞瓜'],
    #     [0.343, 0.099, '壞瓜'], [0.639, 0.161, '壞瓜'], [0.657, 0.198, '壞瓜'],
    #     [0.360, 0.370, '壞瓜'], [0.593, 0.042, '壞瓜'], [0.719, 0.103, '壞瓜']
    # ]

    # Feature names for watermelon dataset 3.0
    labels = ['色澤', '根蒂', '敲擊', '紋理', '臍部', '觸感', '密度', '含糖率']
    # Feature names for watermelon dataset 3.0a
    # labels = ['密度', '含糖率']

    # All possible values of each feature
    labels_full = {}
    for i in range(len(labels)):
        labelList = [example[i] for example in dataSet]
        uniqueLabel = set(labelList)
        labels_full[labels[i]] = uniqueLabel
    print("--------------------------------------")
    for item in labels_full:
        print("item=", unicode(item))
    print("--------------------------------------")
    print("len(labels_full)=", len(labels_full))
    print("len(labels)=", len(labels))
    return dataSet, labels, labels_full


def calcShannonEnt(dataSet):
    """
    Compute the information entropy (Shannon entropy) of a data set.
    :param dataSet:
    :return:
    """
    # Total number of samples
    numEntries = len(dataSet)
    # Count the class labels
    labelCounts = collections.defaultdict(int)
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        # Proportion of this class among all samples
        prob = float(labelCounts[key]) / numEntries
        # Logarithm base 2
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSetForSeries(dataSet, axis, value):
    """
    Split the data set on a continuous feature into the samples not
    greater than the split value and the samples greater than it.
    :param dataSet: data set to split
    :param axis: index of the continuous feature
    :param value: split value
    :return:
    """
    # Samples not greater than the split value
    eltDataSet = []
    # Samples greater than the split value
    gtDataSet = []
    # Split; the feature column itself is kept
    for feat in dataSet:
        if feat[axis] <= value:
            eltDataSet.append(feat)
        else:
            gtDataSet.append(feat)
    return eltDataSet, gtDataSet


def splitDataSet(dataSet, axis, value):
    """
    Split the data set on a discrete feature.
    :param dataSet: data set
    :param axis: index of the feature
    :param value: only samples whose feature equals this value are returned
    :return:
    """
    # Build a new list so the original data set is not modified
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Keep everything before the feature column...
            reducedFeatVec = featVec[:axis]
            # ...and everything after it, i.e. drop the used feature
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# This function searches for the split point that maximizes the
# information gain of a continuous feature.
def calcInfoGainForSeries(dataSet, i, baseEntropy):
    """
    Compute the information gain of a continuous feature.
    :param dataSet: the whole data set
    :param i: index of the continuous feature
    :param baseEntropy: entropy of the whole data set
    :return: the maximum information gain and the corresponding split point
    """
    # Largest information gain seen so far
    maxInfoGain = 0.0
    # Best split point seen so far
    bestMid = -1
    # All values of the current feature
    featList = [example[i] for example in dataSet]
    # All class labels
    classList = [example[-1] for example in dataSet]
    # Note: this assumes the continuous values are all distinct (true for
    # this data set); duplicated values would collapse in the dict
    dictList = dict(zip(featList, classList))
    # Sort by the continuous value, from small to large
    sortedFeatList = sorted(dictList.items(), key=operator.itemgetter(0))
    # Number of distinct continuous values
    numberForFeatList = len(sortedFeatList)
    # Candidate split points: midpoints of adjacent sorted values.
    # The original Python 3 code used `i` as the comprehension variable;
    # in Python 2 the comprehension variable leaks and would clobber the
    # parameter `i` of this function, so it is renamed to `k` here.
    midFeatList = [round((sortedFeatList[k][0] + sortedFeatList[k + 1][0]) / 2.0, 3)
                   for k in range(numberForFeatList - 1)]
    # Compute the information gain of each candidate split point
    for mid in midFeatList:
        # Split into the part not greater than mid and the part greater than mid
        eltDataSet, gtDataSet = splitDataSetForSeries(dataSet, i, mid)
        # Weighted sum of the entropies of the two parts
        newEntropy = (float(len(eltDataSet)) / float(len(sortedFeatList)) * calcShannonEnt(eltDataSet)
                      + float(len(gtDataSet)) / float(len(sortedFeatList)) * calcShannonEnt(gtDataSet))
        # Information gain
        infoGain = baseEntropy - newEntropy
        # print('split value: ' + str(mid) + ', information gain: ' + str(infoGain))
        if infoGain > maxInfoGain:
            bestMid = mid
            maxInfoGain = infoGain
    return maxInfoGain, bestMid


def calcInfoGain(dataSet, featList, i, baseEntropy):
    """
    Compute the information gain of a discrete feature.
    :param dataSet: data set
    :param featList: values of the current feature
    :param i: index of the current feature
    :param baseEntropy: entropy of the whole data set
    :return:
    """
    # Distinct values of the current feature
    uniqueVals = set(featList)
    # Conditional entropy for this feature
    newEntropy = 0.0
    for value in uniqueVals:
        # Samples whose feature equals the current value
        subDataSet = splitDataSet(dataSet=dataSet, axis=i, value=value)
        # Weight of this subset
        prob = float(len(subDataSet)) / float(len(dataSet))
        newEntropy += prob * calcShannonEnt(subDataSet)
    # Information gain
    infoGain = baseEntropy - newEntropy
    return infoGain


def chooseBestFeatureToSplit(dataSet, labels):
    """
    Choose the best feature to split on by information gain.
    Handles continuous features as well as discrete ones.
    :param dataSet:
    :return:
    """
    # Number of features
    numFeatures = len(dataSet[0]) - 1
    # Entropy of the whole data set
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    # Whether the current best feature is continuous
    flagSeries = 0
    # Split point, recorded when the best feature is continuous
    bestSeriesMid = 0.0
    # Compute the information gain of every feature
    for i in range(numFeatures):
        # All values of the current feature
        featList = [example[i] for example in dataSet]
        if isinstance(featList[0], str):
            infoGain = calcInfoGain(dataSet, featList, i, baseEntropy)
        else:
            infoGain, bestMid = calcInfoGainForSeries(dataSet, i, baseEntropy)
        # print('feature: ' + labels[i] + ', information gain: ' + str(infoGain))
        # Keep the feature with the largest gain so far
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
            flagSeries = 0
            if not isinstance(dataSet[0][bestFeature], str):
                flagSeries = 1
                bestSeriesMid = bestMid
    # print('feature with the largest information gain: ' + labels[bestFeature])
    if flagSeries:
        return bestFeature, bestSeriesMid
    else:
        return bestFeature


def majorityCnt(classList):
    """
    Return the most frequent class label.
    :param classList:
    :return:
    """
    # Vote count per label
    classCount = collections.defaultdict(int)
    for vote in classList:
        classCount[vote] += 1
    # Sort from most to least frequent
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """
    Build the decision tree recursively.
    :param dataSet: data set
    :param labels: feature names
    :return:
    """
    # Class labels of all samples
    classList = [example[-1] for example in dataSet]
    # If every sample carries the same label, stop splitting
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # If only the class column is left, all features have been used up;
    # return the majority label
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # Pick the best feature to split on
    bestFeat = chooseBestFeatureToSplit(dataSet=dataSet, labels=labels)
    bestFeatLabel = ''
    # 1 = continuous, 0 = discrete
    flagSeries = 0
    # Split point when the feature is continuous
    midSeries = 0.0
    # A tuple means the best feature is continuous: (index, split value)
    if isinstance(bestFeat, tuple):
        # The branch label carries the split value
        bestFeatLabel = str(labels[bestFeat[0]]) + '小于' + str(bestFeat[1]) + '?'
        midSeries = bestFeat[1]
        bestFeat = bestFeat[0]
        flagSeries = 1
    else:
        bestFeatLabel = labels[bestFeat]
        flagSeries = 0
    # The tree is a nested dict keyed by the splitting feature's label
    myTree = {bestFeatLabel: {}}
    # All values of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]
    # Continuous feature
    if flagSeries:
        # Split into the parts not greater than / greater than the split
        # point; the feature is NOT removed from `labels`, which is why a
        # continuous feature can be used again deeper in the tree
        eltDataSet, gtDataSet = splitDataSetForSeries(dataSet, bestFeat, midSeries)
        subLabels = labels[:]
        # Recurse on the "not greater than" branch
        subTree = createTree(eltDataSet, subLabels)
        myTree[bestFeatLabel]['小于'] = subTree
        # Recurse on the "greater than" branch
        subTree = createTree(gtDataSet, subLabels)
        myTree[bestFeatLabel]['大于'] = subTree
        return myTree
    # Discrete feature
    else:
        # A discrete feature is used once and then removed
        del (labels[bestFeat])
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]
            # Recurse on the subset whose feature equals the current value
            subTree = createTree(splitDataSet(dataSet=dataSet, axis=bestFeat, value=value),
                                 subLabels)
            myTree[bestFeatLabel][value] = subTree
        return myTree


if __name__ == '__main__':
    dataSet, labels, labels_full = createDataSet()
    myTree = createTree(dataSet, labels)
    print(myTree)
    treePlotter.createPlot(myTree)
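A quick sanity check on calcShannonEnt before moving on to the plotting code: dataset 3.0 has 8 good and 9 bad melons at the root, so Ent(D) = -(8/17)·log2(8/17) - (9/17)·log2(9/17) ≈ 0.998, the same root entropy the book computes. In two lines:

from math import log
p_good, p_bad = 8 / 17.0, 9 / 17.0
print(-(p_good * log(p_good, 2) + p_bad * log(p_bad, 2)))  # -> ~0.998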
treePlotter.py
#-*- coding:utf-8 -*-
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import matplotlib.pyplot as plt
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"]
import matplotlib
from matplotlib.font_manager import *
plt.rcParams['axes.unicode_minus']=False
import numpy as np
import pandas as pd
from numpy import *
# Make sure the two fonts named below are installed on your system. The
# line below has been tested; the same thing can be done by editing
# matplotlibrc directly.
matplotlib.rcParams['font.sans-serif'] = 'HYQuanTangShiF,Times New Roman'  # Chinese text uses HYQuanTangShiF (漢儀全唐詩體繁); everything else falls back to Times New Roman
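# (A hypothetical check, not in the original post: if the Chinese node
# labels render as empty boxes, the font is probably missing; list the
# font names matplotlib actually sees and substitute one that exists.)
# from matplotlib.font_manager import fontManager
# print(sorted(set(f.name for f in fontManager.ttflist)))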
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


# Return the number of leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # A dict child is an internal node; anything else is a leaf
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# Return the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    nodeTxt = unicode(nodeTxt)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    # Put the branch label halfway along the edge from parent to child
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, unicode(txtString), va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    # The first key is the feature this node splits on
    nodeTxt = unicode(nodeTxt)
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # Internal node: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
#     createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()


def retrieveTree(i):
    # Two hard-coded sample trees for testing the plotting functions
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

# createPlot(thisTree)
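To smoke-test the plotting module on its own (a usage sketch; retrieveTree and createPlot are the functions defined above):

import treePlotter
treePlotter.createPlot(treePlotter.retrieveTree(0))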
Watermelon dataset 3.0 (bundled in the code above) reproduces Figure 4.8 on page 85 of Zhou Zhihua's "Machine Learning".
Watermelon dataset 3.0a (also bundled, commented out inside createDataSet) reproduces Figure 4.10 on page 90 of the same book.