java knn kd树_KNN算法之KD树(K-dimension Tree)实现 K近邻查询
KD樹是一種分割k維數據空間的數據結構,主要應用于多維空間關鍵數據的搜索,如范圍搜索和最近鄰搜索。
KD樹使用了分治的思想,對比二叉搜索樹(BST),KD樹解決的是多維空間內的最近點(K近點)問題。(思想與之前見過的最近點對問題很相似,將所有點分為兩邊,對于可能橫跨劃分線的點對再進一步討論)
KD樹用來優化KNN算法中的查詢復雜度。
一、建樹
建立KDtree,主要有兩步操作:選擇合適的分割維度,選擇中值節點作為分割節點。
分割維度的選擇遵循的原則是,選擇范圍最大的緯度,也即是方差最大的緯度作為分割維度,為什么方差最大的適合作為特征呢?
因為方差大,數據相對“分散”,選取該特征來對數據集進行分割,數據散得更“開”一些。
分割節點的選擇原則是,將這一維度的數據進行排序,選擇正中間的節點作為分割節點,確保節點左邊的點的維度值小于節點的維度值,節點右邊的點的維度值大于節點的維度值。
這兩步步驟影響搜索效率,非常關鍵。
二、搜索K近點
需要的數據結構:最大堆(此處我對距離取負從而用最小堆實現的最大堆,因為python的heapq模塊只有最小堆)、堆棧(用列表實現)
a.利用二叉搜索找到葉子節點并將搜索的結點路徑壓入堆棧stack中
b.通過保存在堆棧中的搜索路徑回溯,直至堆棧中沒有結點了
對于b步驟,需要區分葉子結點和非葉結點:
1、葉子結點:
葉子結點:計算與目標點的距離。若候選堆中不足K個點,將葉子結點加入候選堆中;如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大, 說明應當加入候選堆;
2、非葉結點:
對于非葉結點,處理步驟和葉子結點差不多,只是需要額外考慮以目標點為圓心,最大堆的堆頂元素為半徑的超球體是否和劃分當前空間的超平面相交,如果相交說明未訪問的另一邊的空間有可能包含比當前已有的K個近點更近的點,需要搜索另一邊的空間;此外,當候選堆中沒有K個點,那么不管有沒有相交,都應當搜索未訪問的另一邊空間,因為候選堆的點不夠K個。
步驟:計算與目標點的距離
1、若不足K個點,將結點加入候選堆中;
如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大。
2、判斷候選堆中的最小距離是否小于Xi離當前超平面的距離(即是否需要判斷未訪問的另一邊要不要搜索)當然如果不足K個點,雖然超平面不相交,依舊要搜索另一邊,直到找到葉子結點,并且把路徑加入回溯棧中。
三、預測
KNN通常用來分類或者回歸問題,筆者已經封裝好了兩種預測的方法。
python代碼實現:
1 importheapq2 classKDNode(object):3 def __init__(self,feature=None,father=None,left=None,right=None,split=None):4 self.feature=feature #dimension index (按第幾個維度的特征進行劃分的)
5 self.father=father #并沒有用到6 self.left=left7 self.right=right8 self.split=split #X value and Y value (元組,包含特征X和真實值Y)
9
10 classKDTree(object):11 def __init__(self):12 self.root=KDNode()13 pass
14
15 def_get_variance(self,X,row_indexes,feature_index):16 #X (2D list): samples * dimension
17 #row_indexes (1D list): choose which row can be calculated
18 #feature_index (int): calculate which dimension
19 n =len(row_indexes)20 sum1 =021 sum2 =022 for id inrow_indexes:23 sum1 = sum1 +X[id][feature_index]24 sum2 = sum2 + X[id][feature_index]**2
25
26 return sum2/n - (sum1/n)**2
27
28 def_get_max_variance_feature(self,X,row_indexes):29 mx_var = -1
30 dim_index = -1
31 for dim inrange(len(X[0])):32 dim_var =self._get_variance(X,row_indexes,dim)33 if dim_var>mx_var:34 mx_var=dim_var35 dim_index=dim36 #return max variance feature index (int)
37 returndim_index38
39 def_get_median_index(self,X,row_indexes,feature_index):40 median_index = len(row_indexes)//2
41 select_X = [(idx,X[idx][feature_index]) for idx inrow_indexes]42 sorted_X =select_X43 sorted(sorted_X,key= lambda x:x[1])44 #return median index in feature_index dimension (int)
45 returnsorted_X[median_index][0]46
47 def_split_feature(self,X,row_indexes,feature_index,median_index):48 left_ids =[]49 right_ids =[]50 median_val =X[median_index][feature_index]51 for id inrow_indexes:52 if id==median_index:53 continue
54 val =X[id][feature_index]55 if val <56 left_ids.append else:58 right_ids.append points index and right>
60 #把當前的樣本按feature維度進行劃分為兩份
61 returnleft_ids, right_ids62
63 defbuild_tree(self,X,Y):64 row_indexes =[i for i inrange(len(X))]65 node =self.root66 queue =[(node,row_indexes)]67 #BFS創建KD樹
68 whilequeue:69 root,ids =queue.pop(0)70 if len(ids)==1:71 root.feature = 0 #如果是葉子結點,維度賦0
72 root.split =(X[ids[0]],Y[ids[0]])73 continue
74 #選取方差最大的特征維度劃分,取樣本的中位數作為median
75 feature_index =self._get_max_variance_feature(X,ids)76 median_index =self._get_median_index(X,ids,feature_index)77 left_ids,right_ids =self._split_feature(X,ids,feature_index,median_index)78 root.feature =feature_index79 root.split =(X[median_index],Y[median_index])80 ifleft_ids:81 root.left =KDNode()82 root.left.father =root83 queue.append((root.left,left_ids))84 ifright_ids:85 root.right =KDNode()86 root.right.father =root87 queue.append((root.right,right_ids))88
89 def _get_distance(self,Xi,node,p=2):90 #p=2 default Euclidean distance
91 nx =node.split[0]92 dist =093 for i inrange(len(Xi)):94 dist=dist + (abs(Xi[i]-nx[i])**p)95 dist = dist**(1/p)96 returndist97
98 def_get_hyperplane_distance(self,Xi,node):99 xx =node.split[0]100 dist = abs(Xi[node.feature] -xx[node.feature])101 returndist102
103 def_is_leaf(self,node):104 if node.left is None and node.right isNone:105 returnTrue106 else:107 returnFalse108
109 def get_nearest_neighbour(self,Xi,K=1):110 search_paths =[]111 max_heap = [] #use min heap achieve max heap (因為python只有最小堆)
112 priority_num = 0 #remove same distance
113 heapq.heappush(max_heap,(float('-inf'),priority_num,None))114 priority_num +=1
115 node =self.root116 #找到離Xi最近的葉子結點
117 while node is notNone:118 search_paths.append(node)119 if Xi[node.feature]
124 whilesearch_paths:125 now =search_paths.pop()126 #葉子結點:計算與Xi的距離,若不足K個點,將葉子結點加入候選堆中;
127 #如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大,
128 #說明應當加入候選堆;
129 ifself._is_leaf(now):130 dist =self._get_distance(Xi,now)131 dist = -dist132 mini_dist =max_heap[0][0]133 if len(max_heap)
136 elif dist >mini_dist:137 _ =heapq.heappop(max_heap)138 heapq.heappush(max_heap,(dist,priority_num,now))139 priority_num+=1
140 #非葉結點:計算與Xi的距離
141 #1、若不足K個點,將結點加入候選堆中;
142 #如果K個點夠了,判斷是否比候選堆中距離最小的結點(因為距離取了相反數)還要大,
143 #2、判斷候選堆中的最小距離是否小于Xi離當前超平面的距離(即是否需要判斷另一邊要不要搜索)
144 #當然如果不足K個點,雖然超平面不相交,依舊要搜索另一邊,
145 #直到找到葉子結點,并且把路徑加入回溯棧中
146 else:147 dist =self._get_distance(Xi, now)148 dist = -dist149 mini_dist =max_heap[0][0]150 if len(max_heap) <151 heapq.heappush priority_num now>
153 elif dist >mini_dist:154 _ =heapq.heappop(max_heap)155 heapq.heappush(max_heap, (dist, priority_num, now))156 priority_num += 1
157
158 mini_dist =max_heap[0][0]159 if len(max_heap)mini_dist:160 #search another child tree
161 if Xi[now.feature] >=now.split[0][now.feature]:162 child_node =now.left163 else:164 child_node =now.right165 #record path until find child leaf node
166 while child_node is notNone:167 search_paths.append(child_node)168 if Xi[child_node.feature]
174 def predict_classification(self,Xi,K=1):175 #多分類問題預測
176 y =self.get_nearest_neighbour(Xi,K)177 mp ={}178 for i iny:179 if i[2].split[1] inmp:180 mp[i[2].split[1]]+=1
181 else:182 mp[i[2].split[1]]=1
183 pre_y = -1
184 max_cnt =-1
185 for k,v inmp.items():186 if v>max_cnt:187 max_cnt=v188 pre_y=k189 returnpre_y190
191 def predict_regression(self,Xi,K=1):192 #回歸問題預測
193 pre_y =self.get_nearest_neighbour(Xi,K)194 return sum([i[2].split[1] for i in pre_y])/K195
196
197 #t =KDTree()
198 #xx = [[3,3],[1,2],[5,6],[999,999],[5,5]]
199 #z = [1,0,1,1,1]
200 #t.build_tree(xx,z)
201 #y=t.predict_regression([4,5.5],K=5)
202 #print(y)
參考資料:
《統計學習方法》——李航著
https://blog.csdn.net/qq_32478489/article/details/82972391?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
https://zhuanlan.zhihu.com/p/45346117
https://www.cnblogs.com/xingzhensun/p/9693362.html
151>56>總結
以上是生活随笔為你收集整理的java knn kd树_KNN算法之KD树(K-dimension Tree)实现 K近邻查询的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 熟酸枣仁的功效与作用、禁忌和食用方法
- 下一篇: 2023新一代人工智能(深圳)创业大赛评