Experimenting with a Decision Tree on the MNIST Dataset
SVM and KNN have already been tried on MNIST in earlier posts, with fairly good results. This time the experiment uses a decision tree to see how it compares.
In the decision tree used here, the MNIST grayscale values are reduced to 0/1, which simplifies both splitting and the entropy computation.
Using a smaller amount of test data, I first measured how accuracy changes when the grayscale values are bucketed into more than two categories. The results:
#Test change pixel data into more categories than 0/1:
#int(pixel)/50: 37%
#int(pixel)/64: 45.9%
#int(pixel)/96: 52.3%
#int(pixel)/128: 62.48%
#int(pixel)/152: 59.1%
#int(pixel)/176: 57.6%
#int(pixel)/192: 54.0%
As the numbers show, the binary 0/1 treatment of the grayscale data works best, with int(pixel)/128 (i.e. thresholding at 128) the strongest variant.
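For reference, a minimal sketch of the bucketing behind the table above; bucketPixel is a name introduced here for illustration, and the integer division follows the Python 2 semantics of the code at the bottom:

#Bucket a 0-255 grayscale pixel by integer division (Python 2 semantics).
def bucketPixel(pixel, divisor=128):
    return int(pixel) / divisor

#divisor 128 yields 0 or 1, i.e. plain binarization at threshold 128:
print bucketPixel(90)       #-> 0
print bucketPixel(200)      #-> 1
#smaller divisors yield more buckets, e.g. divisor 64 gives 0..3:
print bucketPixel(200, 64)  #-> 3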
With the 0/1 treatment, the final results are:
#Result:
#Train with 10k, test with 60k: 77.79%
#Train with 60k, test with 10k: 87.3%
#Time cost: 3 hours.
The final accuracy is 87.3%. Compared with the 95%+ achieved by SVM and KNN, that is a sizable gap, and the decision tree also takes longer to run.
Note that the decision tree in this experiment is not pruned, so there is still room for improvement.
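As one hedged illustration of that improvement (not part of the original code): a simple pre-pruning rule stops splitting when the best split's information gain falls below a threshold and returns the majority class instead. A minimal sketch, assuming the calcShannonEntropy, splitDataSet, and majorityCnt functions from the code at the bottom; the name splitIsWorthwhile and the 0.01 threshold are assumptions made here:

#Hypothetical pre-pruning helper (not in the original code).
#Returns True if splitting on bestFeatureIndex gains at least minGain bits.
def splitIsWorthwhile(dataSet, bestFeatureIndex, minGain=0.01):
    baseEntropy = calcShannonEntropy(dataSet)
    newEntropy = 0.0
    for value in set([example[bestFeatureIndex] for example in dataSet]):
        subDataSet = splitDataSet(dataSet, bestFeatureIndex, value)
        prob = len(subDataSet)/float(len(dataSet))
        newEntropy += prob * calcShannonEntropy(subDataSet)
    return (baseEntropy - newEntropy) >= minGain

#In createDecisionTree, after choosing bestFeatureIndex, one could then add:
#    if not splitIsWorthwhile(dataSet, bestFeatureIndex):
#        return majorityCnt(classList)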
The Python code is at the bottom. In it:
calcShannonEntropy(dataSet): computes the entropy of the data matrix from the class distribution of its rows, using the Shannon entropy formula;
splitDataSet(dataSet, axis, value): returns the matrix of all rows whose value in dimension axis equals value. For each value appearing in dimension axis, computing the entropy of its splitDataSet matrix and summing these entropies, each weighted by the probability of that value, gives the overall entropy after splitting on dimension axis.
chooseBestFeatureToSplit(dataSet): using splitDataSet, compares this post-split entropy against the entropy of the original matrix and returns the dimension with the largest information gain, i.e. the largest reduction in entropy (see the worked example after this list). The tree splits on this feature.
createDecisionTree(dataSet, features): builds the decision tree recursively. When a leaf cannot be classified unambiguously, majorityCnt(classList) returns the most frequent class as the prediction.
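A tiny worked example (constructed here for illustration) of the entropy and information-gain computation, using calcShannonEntropy and chooseBestFeatureToSplit from the code below; each row is [class, feature #1, feature #2]:

#Toy data: four samples, class label in column 0, two binary features.
toyData = [['A', 1, 0],
           ['A', 1, 1],
           ['B', 0, 0],
           ['B', 0, 1]]
#Classes split 2/2, so the base entropy is -(0.5*log2(0.5))*2 = 1.0 bit.
print calcShannonEntropy(toyData)        #-> 1.0
#Feature 1 separates A from B perfectly (gain 1.0); feature 2 gains nothing.
print chooseBestFeatureToSplit(toyData)  #-> 1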
The code follows:
#Decision tree for MNIST dataset by arthur503.
#Data format: 'class label1:pixel label2:pixel ...'
#Warning: without fix overfitting!
#
#Test change pixel data into more categories than 0/1:
#int(pixel)/50: 37%
#int(pixel)/64: 45.9%
#int(pixel)/96: 52.3%
#int(pixel)/128: 62.48%
#int(pixel)/152: 59.1%
#int(pixel)/176: 57.6%
#int(pixel)/192: 54.0%
#
#Result:
#Train with 10k, test with 60k: 77.79%
#Train with 60k, test with 10k: 87.3%
#Time cost: 3 hours.

from numpy import *    #provides log2 and shape below.
import operator

def calcShannonEntropy(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featureVec in dataSet:
        currentLabel = featureVec[0]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 1
        else:
            labelCounts[currentLabel] += 1
    shannonEntropy = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEntropy -= prob * log2(prob)
    return shannonEntropy

#get all rows whose axis item equals value.
def splitDataSet(dataSet, axis, value):
    subDataSet = []
    for featureVec in dataSet:
        if featureVec[axis] == value:
            reducedFeatureVec = featureVec[:axis]
            reducedFeatureVec.extend(featureVec[axis+1:])    #if axis == -1, this will cause error!
            subDataSet.append(reducedFeatureVec)
    return subDataSet

def chooseBestFeatureToSplit(dataSet):
    #Notice: Actually, index 0 of numFeatures is not a feature (it is the class label).
    numFeatures = len(dataSet[0])
    baseEntropy = calcShannonEntropy(dataSet)
    bestInfoGain = 0.0
    bestFeature = numFeatures - 1    #DO NOT use -1! or splitDataSet(dataSet, -1, value) will cause error!
    #feature index starts with 1 (not 0)!
    for i in range(numFeatures)[1:]:
        featureList = [example[i] for example in dataSet]
        featureSet = set(featureList)
        newEntropy = 0.0
        for value in featureSet:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEntropy(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

#classify on leaf of decision tree.
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

#Create Decision Tree.
def createDecisionTree(dataSet, features):
    print 'create decision tree... length of features is:'+str(len(features))
    classList = [example[0] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeatureIndex = chooseBestFeatureToSplit(dataSet)
    bestFeatureLabel = features[bestFeatureIndex]
    myTree = {bestFeatureLabel:{}}
    del(features[bestFeatureIndex])
    featureValues = [example[bestFeatureIndex] for example in dataSet]
    featureSet = set(featureValues)
    for value in featureSet:
        subFeatures = features[:]
        myTree[bestFeatureLabel][value] = createDecisionTree(splitDataSet(dataSet, bestFeatureIndex, value), subFeatures)
    return myTree

def line2Mat(line):
    mat = line.strip().split(' ')
    for i in range(len(mat)-1):
        pixel = mat[i+1].split(':')[1]
        #change MNIST pixel data into 0/1 format.
        mat[i+1] = int(pixel)/128
    return mat

#return matrix as a list (instead of a matrix).
#features is the 28*28 pixels in MNIST dataset.
def file2Mat(fileName):
    f = open(fileName)
    lines = f.readlines()
    matrix = []
    for line in lines:
        mat = line2Mat(line)
        matrix.append(mat)
    f.close()
    print 'Read file '+str(fileName)+' to array done! Matrix shape:'+str(shape(matrix))
    return matrix

#Classify test file.
def classify(inputTree, featureLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featureIndex = featureLabels.index(firstStr)
    predictClass = '-1'
    for key in secondDict.keys():
        if testVec[featureIndex] == key:
            if type(secondDict[key]) == type({}):
                predictClass = classify(secondDict[key], featureLabels, testVec)
            else:
                predictClass = secondDict[key]
    return predictClass

def classifyTestFile(inputTree, featureLabels, testDataSet):
    rightCnt = 0
    for i in range(len(testDataSet)):
        classLabel = testDataSet[i][0]
        predictClassLabel = classify(inputTree, featureLabels, testDataSet[i])
        if classLabel == predictClassLabel:
            rightCnt += 1
        if i % 200 == 0:
            print 'num '+str(i)+'. ratio: '+str(float(rightCnt)/(i+1))
    return float(rightCnt)/len(testDataSet)

def getFeatureLabels(length):
    strs = []
    for i in range(length):
        strs.append('#'+str(i))
    return strs

#Normal file
trainFile = 'train_60k.txt'
testFile = 'test_10k.txt'
#Scaled file
#trainFile = 'train_60k_scale.txt'
#testFile = 'test_10k_scale.txt'
#Test file
#trainFile = 'test_only_1.txt'
#testFile = 'test_only_2.txt'

#train decision tree.
dataSet = file2Mat(trainFile)
#Actually, the 0 item is class, not feature labels.
featureLabels = getFeatureLabels(len(dataSet[0]))
print 'begin to create decision tree...'
myTree = createDecisionTree(dataSet, featureLabels)
print 'create decision tree done.'

#predict with decision tree.
testDataSet = file2Mat(testFile)
featureLabels = getFeatureLabels(len(testDataSet[0]))
rightRatio = classifyTestFile(myTree, featureLabels, testDataSet)
print 'total right ratio: '+str(rightRatio)
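For reference, line2Mat above expects one sample per line in the 'class index:pixel ...' format named in the header comment. A hypothetical truncated example (real lines carry 784 index:pixel pairs; the pixel values here are invented):

sampleLine = '5 1:0 2:182 3:255'
#line2Mat strips the 'index:' prefixes and binarizes each pixel at 128:
print line2Mat(sampleLine)    #-> ['5', 0, 1, 1]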