KD Tree的原理及Python实现
1. 原理篇
我們用大白話講講KD-Tree是怎么一回事。
1.1 線性查找
假設(shè)數(shù)組A為[0, 6, 3, 8, 7, 4, 11],有一個元素x,我們要找到數(shù)組A中距離x最近的元素,應(yīng)該如何實現(xiàn)呢?比較直接的想法是用數(shù)組A中的每一個元素與x作差,差的絕對值最小的那個元素就是我們要找的元素。假設(shè)x = 2,那么用數(shù)組A中的所有元素與x作差得到[-2, 4, 1, 6, 5, 2, 9],其中絕對值最小的是1,對應(yīng)的元素是數(shù)組A中的3,所以3就是我們的查找結(jié)果。
1.2 二分查找
如果我們有大量的元素要在數(shù)組A中進行查找,那么1.1的方式就顯得不是那么高效了,如果數(shù)組A的長度為N,那么每次查找都要進行N次操作,即算法復(fù)雜度為O(N)。
這種查找方法就是二分查找,其算法復(fù)雜度為O(Log2(N))。
1.3 BST
除了數(shù)組之外,有沒有更直觀的數(shù)據(jù)結(jié)構(gòu)可以實現(xiàn)1.2的二分查找呢?答案就是二分查找樹,全稱Binary Search Tree,簡稱BST。把數(shù)組A建立成一個BST,結(jié)構(gòu)如下圖所示。我們只需要訪問根節(jié)點,進行值比較來確定下一節(jié)點,如此循環(huán)往復(fù)直到訪問到葉子節(jié)點為止。
1.4 多維數(shù)組
現(xiàn)在我們把問題加點難度,假設(shè)數(shù)組B為[[6, 2], [6, 3], [3, 5], [5, 0], [1, 2], [4, 9], [8, 1]],有一個元素x,我們要找到數(shù)組B中距離x最近的元素,應(yīng)該如何實現(xiàn)呢?比較直接的想法是用數(shù)組B中的每一個元素與x求距離,距離最小的那個元素就是我們要找的元素。假設(shè)x = [1, 1],那么用數(shù)組A中的所有元素與x求距離得到[5.0, 5.4, 4.5, 4.1, 1.0, 8.5, 7.0],其中距離最小的是1,對應(yīng)的元素是數(shù)組B中的[1, 2],所以[1, 2]就是我們的查找結(jié)果。
1.5 再次陷入困境
如果我們有大量的元素要在數(shù)組B中進行查找,那么1.4的方式就又顯得不是那么高效了,如果數(shù)組B的長度為N,那么每次查找都要進行N次操作,即算法復(fù)雜度為O(N)。
1.6 什么是KD-Tree
這時候已經(jīng)沒辦法用BST,不過我們可以對BST做一些改變來適應(yīng)多維數(shù)組的情況。當(dāng)當(dāng)當(dāng)當(dāng)~,這時候該KD-Tree出場了。廢話不多說,先上圖:
1.7 如何建立KD-Tree
您可能會問,剛在那張圖的KD Tree又是如何建立的呢? 很簡單,只需要5步:
1. 建立根節(jié)點;
2. 選取方差最大的特征作為分割特征;
3. 選擇該特征的中位數(shù)作為分割點;
4. 將數(shù)據(jù)集中該特征小于中位數(shù)的傳遞給根節(jié)點的左兒子,大于中位數(shù)的傳遞給根節(jié)點的右兒子;
5. 遞歸執(zhí)行步驟2-4,直到所有數(shù)據(jù)都被建立到KD Tree的節(jié)點上為止。
不難看出,KD Tree的建立步驟跟BST是非常相似的,可以認(rèn)為BST是KD Tree在一維數(shù)據(jù)上的特例。KD Tree的算法復(fù)雜度介于O(Log2(N))和O(N)之間。
1.8 特征選取
您可能還會問,為什么方差最大的適合作為特征呢? 因為方差大,數(shù)據(jù)相對“分散”,選取該特征來對數(shù)據(jù)集進行分割,數(shù)據(jù)散得更“開”一些。
1.9 分割點選擇
您可能又要問,為什么選擇中位數(shù)作為分割點呢? 因為借鑒了BST,選取中位數(shù),讓左子樹和右子樹的數(shù)據(jù)數(shù)量一致,便于二分查找。
1.10 利用KD-Tree查找元素
KD Tree建好之后,接下來就要利用KD Tree對元素進行查找了。查找的方式在BST的基礎(chǔ)上又增加了一些難度,如下:
1. 從根節(jié)點開始,根據(jù)目標(biāo)在分割特征中是否小于或大于當(dāng)前節(jié)點,向左或向右移動。
2. 一旦算法到達葉節(jié)點,它就將節(jié)點點保存為“當(dāng)前最佳”。
3. 回溯,即從葉節(jié)點再返回到根節(jié)點
4. 如果當(dāng)前節(jié)點比當(dāng)前最佳節(jié)點更接近,那么它就成為當(dāng)前最好的。
5. 如果目標(biāo)距離當(dāng)前節(jié)點的父節(jié)點所在的將數(shù)據(jù)集分割為兩份的超平面的距離更接近,說明當(dāng)前節(jié)點的兄弟節(jié)點所在的子樹有可能包含更近的點。因此需要對這個兄弟節(jié)點遞歸執(zhí)行1-4步。
1.11 超平面
所以什么是超平面呢,聽起來讓人一臉懵逼。
以[0, 2, 0], [1, 4, 3], [2, 6, 1]的舉例:
1. 如果用第二維特征作為分割特征,那么從三個數(shù)據(jù)點中的對應(yīng)特征取出2, 4, 6,中位數(shù)是4;
2. 所以[1, 4, 3]作為分割點,將[0, 2, 0]劃分到左邊,[2, 6, 1]劃分到右邊;
3. 從立體幾何的角度考慮,三維空間得用一個二維的平面才能把空間一分為二,這個平面可以用y = 4來表示;
4. 點[0, 2, 0]到超平面y = 4的距離就是 sqrt((2 - 4) ^ 2) = 2;
5. 點[2, 6, 1]到超平面y = 4的距離就是 sqrt((6 - 4) ^ 2) = 2。
2. 實現(xiàn)篇
本人用全宇宙最簡單的編程語言——Python實現(xiàn)了KD-Tree算法,沒有依賴任何第三方庫,便于學(xué)習(xí)和使用。簡單說明一下實現(xiàn)過程,更詳細的注釋請參考本人github上的代碼。
2.1 創(chuàng)建Node類
初始化,存儲父節(jié)點、左節(jié)點、右節(jié)點、特征及分割點。
class Node(object):def __init__(self):self.father = Noneself.left = Noneself.right = Noneself.feature = Noneself.split = None2.2 獲取Node的各個屬性
def __str__(self):return "feature: %s, split: %s" % (str(self.feature), str(self.split))2.3 獲取Node的兄弟節(jié)點
@property def brother(self):if self.father is None:ret = Noneelse:if self.father.left is self:ret = self.father.rightelse:ret = self.father.leftreturn ret2.4 創(chuàng)建KDTree類
初始化,存儲根節(jié)點。
class KDTree(object):def __init__(self):self.root = Node()2.5 獲取KDTree屬性
便于我們查看KD Tree的節(jié)點值,各個節(jié)點之間的關(guān)系。
def __str__(self):ret = []i = 0que = [(self.root, -1)]while que:nd, idx_father = que.pop(0)ret.append("%d -> %d: %s" % (idx_father, i, str(nd)))if nd.left is not None:que.append((nd.left, i))if nd.right is not None:que.append((nd.right, i))i += 1return "\n".join(ret)2.6 獲取數(shù)組中位數(shù)的下標(biāo)
def _get_median_idx(self, X, idxs, feature):n = len(idxs)k = n // 2col = map(lambda i: (i, X[i][feature]), idxs)sorted_idxs = map(lambda x: x[0], sorted(col, key=lambda x: x[1]))median_idx = list(sorted_idxs)[k]return median_idx2.7 計算特征的方差
注意這里用到了方差公式,D(X) = E(X^2)-[E(X)]^2
def _get_variance(self, X, idxs, feature):n = len(idxs)col_sum = col_sum_sqr = 0for idx in idxs:xi = X[idx][feature]col_sum += xicol_sum_sqr += xi ** 2return col_sum_sqr / n - (col_sum / n) ** 22.8 選擇特征
取方差最大的的特征作為分割點特征。
def _choose_feature(self, X, idxs):m = len(X[0])variances = map(lambda j: (j, self._get_variance(X, idxs, j)), range(m))return max(variances, key=lambda x: x[1])[0]2.9 分割特征
把大于、小于中位數(shù)的元素分別放到兩個列表中。
def _split_feature(self, X, idxs, feature, median_idx):idxs_split = [[], []]split_val = X[median_idx][feature]for idx in idxs:if idx == median_idx:continuexi = X[idx][feature]if xi < split_val:idxs_split[0].append(idx)else:idxs_split[1].append(idx)return idxs_split2.10 建立KDTree
使用廣度優(yōu)先搜索的方式建立KD Tree,注意要對X進行歸一化。
def build_tree(self, X, y):X_scale = min_max_scale(X)nd = self.rootidxs = range(len(X))que = [(nd, idxs)]while que:nd, idxs = que.pop(0)n = len(idxs)if n == 1:nd.split = (X[idxs[0]], y[idxs[0]])continuefeature = self._choose_feature(X_scale, idxs)median_idx = self._get_median_idx(X, idxs, feature)idxs_left, idxs_right = self._split_feature(X, idxs, feature, median_idx)nd.feature = featurend.split = (X[median_idx], y[median_idx])if idxs_left != []:nd.left = Node()nd.left.father = ndque.append((nd.left, idxs_left))if idxs_right != []:nd.right = Node()nd.right.father = ndque.append((nd.right, idxs_right))2.11 搜索輔助函數(shù)
比較目標(biāo)元素與當(dāng)前結(jié)點的當(dāng)前feature,訪問對應(yīng)的子節(jié)點。反復(fù)執(zhí)行上述過程,直到到達葉子節(jié)點。
def _search(self, Xi, nd):while nd.left or nd.right:if nd.left is None:nd = nd.rightelif nd.right is None:nd = nd.leftelse:if Xi[nd.feature] < nd.split[0][nd.feature]:nd = nd.leftelse:nd = nd.rightreturn nd2.12 歐氏距離
計算目標(biāo)元素與某個節(jié)點的歐氏距離,注意get_euclidean_distance這個函數(shù)沒有進行開根號的操作,所以求出來的是歐氏距離的平方。
def _get_eu_dist(self, Xi, nd):X0 = nd.split[0]return get_euclidean_distance(Xi, X0)2.13 超平面距離
計算目標(biāo)元素與某個節(jié)點所在超平面的歐氏距離,為了跟2.11保持一致,要加上平方。
def _get_hyper_plane_dist(self, Xi, nd):j = nd.featureX0 = nd.split[0]return (Xi[j] - X0[j]) ** 22.14 搜索函數(shù)
搜索KD Tree中與目標(biāo)元素距離最近的節(jié)點,使用廣度優(yōu)先搜索來實現(xiàn)。
def nearest_neighbour_search(self, Xi):dist_best = float("inf")nd_best = self._search(Xi, self.root)que = [(self.root, nd_best)]while que:nd_root, nd_cur = que.pop(0)while 1:dist = self._get_eu_dist(Xi, nd_cur)if dist < dist_best:dist_best = distnd_best = nd_curif nd_cur is not nd_root:nd_bro = nd_cur.brotherif nd_bro is not None:dist_hyper = self._get_hyper_plane_dist(Xi, nd_cur.father)if dist > dist_hyper:_nd_best = self._search(Xi, nd_bro)que.append((nd_bro, _nd_best))nd_cur = nd_cur.fatherelse:breakreturn nd_best3 效果評估
3.1 線性查找
用“笨”辦法查找距離最近的元素。
def exhausted_search(X, Xi):dist_best = float('inf')row_best = Nonefor row in X:dist = get_euclidean_distance(Xi, row)if dist < dist_best:dist_best = distrow_best = rowreturn row_best3.2 main函數(shù)
主函數(shù)分為如下幾個部分:
1. 隨機生成數(shù)據(jù)集,即測試用例
2. 建立KD-Tree
3. 執(zhí)行“笨”辦法查找
4. 比較“笨”辦法和KD-Tree的查找結(jié)果
def main():def main():print("Testing KD Tree...")test_times = 100run_time_1 = run_time_2 = 0for _ in range(test_times):low = 0high = 100n_rows = 1000n_cols = 2X = gen_data(low, high, n_rows, n_cols)y = gen_data(low, high, n_rows)Xi = gen_data(low, high, n_cols) tree = KDTree()tree.build_tree(X, y)start = time()nd = tree.nearest_neighbour_search(Xi)run_time_1 += time() - startret1 = get_euclidean_distance(Xi, nd.split[0])start = time()row = exhausted_search(X, Xi)run_time_2 += time() - startret2 = get_euclidean_distance(Xi, row)assert ret1 == ret2, "target:%s\nrestult1:%s\nrestult2:%s\ntree:\n%s" \% (str(Xi), str(nd), str(row), str(tree)) print("%d tests passed!" % test_times) print("KD Tree Search %.2f s" % run_time_1) print("Exhausted search %.2f s" % run_time_2)
https://github.com/tushushu/imylu/tree/master/imylu/utils
總結(jié)
以上是生活随笔為你收集整理的KD Tree的原理及Python实现的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python基础入门学习笔记 (2)
- 下一篇: 服务器部署docker