K 近邻法(K-Nearest Neighbor, K-NN)
文章目錄
- 1. k近鄰算法
- 2. k近鄰模型
- 2.1 模型
- 2.2 距離度量
- 2.2.1 距離計算代碼 Python
- 2.3 kkk 值的選擇
- 2.4 分類決策規則
- 3. 實現方法, kd樹
- 3.1 構造 kdkdkd 樹
- Python 代碼
- 3.2 搜索 kdkdkd 樹
- Python 代碼
- 4. 鳶尾花KNN分類
- 4.1 KNN實現
- 4.2 sklearn KNN
- 5. 文章完整代碼
k近鄰法(k-nearest neighbor,k-NN)是一種基本分類與回歸方法。
- 輸入:實例的特征向量,對應于特征空間的點
- 輸出:實例的類別,可以取多類
- 假設:給定一個訓練數據集,其中的實例類別已定。
- 分類:對新的實例,根據其k個最近鄰的訓練實例的類別,通過多數表決等方式進行預測。因此,k近鄰法不具有顯式的學習過程。
- k近鄰法實際上利用訓練數據集對特征向量空間進行劃分,并作為其分類的“模型”。
k近鄰法1968年由Cover和Hart提出。
1. k近鄰算法
輸入:一組訓練數據集,特征向量 xix_ixi?,及其類別 yiy_iyi?,給定實例特征向量 xxx
輸出:實例 xxx 所屬的類 yyy
y=arg?max?cj∑xi∈Nk(x)I(yi=cj),i=1,2,...,N,j=1,2,...,Ky = \argmax\limits_{c_j} \sum\limits_{x_i \in N_k(x) } I(y_i = c_j),\quad i=1,2,...,N, j = 1,2,...,Ky=cj?argmax?xi?∈Nk?(x)∑?I(yi?=cj?),i=1,2,...,N,j=1,2,...,K
III 為指示函數,表示當 yi=cjy_i=c_jyi?=cj? 時 III 為 1, 否則 III 為 0
當 k=1k=1k=1 時,特殊情況,稱為最近鄰算法,跟它距離最近的點作為其分類
2. k近鄰模型
三要素:k值的選擇、距離度量、分類決策規則
2.1 模型
- kkk 近鄰模型,三要素確定后,對于任何一個新的輸入實例,它的類唯一確定。
- 這相當于根據上述要素將特征空間劃分為一些子空間,確定子空間里的每個點所屬的類。這一事實從最近鄰算法中可以看得很清楚。
2.2 距離度量
空間中兩個點的距離是兩個實例相似程度的反映。
- LpL_pLp? 距離:
設特征 xix_ixi? 是 nnn 維的,Lp(xi,xj)=(∑l=1n∣xi(l)?xj(l)∣p)1pL_p(x_i,x_j) = \bigg(\sum\limits_{l=1}^n |x_i^{(l)}-x_j^{(l)}|^p \bigg)^{\frac{1}{p}}Lp?(xi?,xj?)=(l=1∑n?∣xi(l)??xj(l)?∣p)p1? - 歐氏距離:上面 p=2p=2p=2 時,L2(xi,xj)=(∑l=1n∣xi(l)?xj(l)∣2)12L_2(x_i,x_j) = \bigg(\sum\limits_{l=1}^n |x_i^{(l)}-x_j^{(l)}|^2 \bigg)^{\frac{1}{2}}L2?(xi?,xj?)=(l=1∑n?∣xi(l)??xj(l)?∣2)21?
- 曼哈頓距離:上面 p=1p=1p=1 時,L1(xi,xj)=∑l=1n∣xi(l)?xj(l)∣L_1(x_i,x_j) = \sum\limits_{l=1}^n |x_i^{(l)}-x_j^{(l)}|L1?(xi?,xj?)=l=1∑n?∣xi(l)??xj(l)?∣
- 切比雪夫距離:當 p=∞p=\inftyp=∞ 時,它是坐標距離的最大值:L∞(xi,xj)=max?l∣xi(l)?xj(l)∣L_\infty(x_i,x_j) = \max\limits_l |x_i^{(l)}-x_j^{(l)}|L∞?(xi?,xj?)=lmax?∣xi(l)??xj(l)?∣
2.2.1 距離計算代碼 Python
import mathdef L_p(xi, xj, p=2):if len(xi) == len(xj) and len(xi) > 0:sum = 0for i in range(len(xi)):sum += math.pow(abs(xi[i] - xj[i]), p)return math.pow(sum, 1 / p)else:return 0 x1 = [1, 1] x2 = [5, 1] x3 = [4, 4] X = [x1, x2, x3] for i in range(len(X)):for j in range(i + 1, len(X)):for p in range(1, 5):print("x%d,x%d的L%d距離是:%.2f" % (i + 1, j + 1, p, L_p(X[i], X[j], p))) x1,x2的L1距離是:4.00 x1,x2的L2距離是:4.00 x1,x2的L3距離是:4.00 x1,x2的L4距離是:4.00 x1,x3的L1距離是:6.00 x1,x3的L2距離是:4.24 x1,x3的L3距離是:3.78 x1,x3的L4距離是:3.57 x2,x3的L1距離是:4.00 x2,x3的L2距離是:3.16 x2,x3的L3距離是:3.04 x2,x3的L4距離是:3.012.3 kkk 值的選擇
-
k值的選擇會對k近鄰法的結果產生重大影響。
-
選較小的 k 值,相當于用較小的鄰域中的訓練實例進行預測,“學習”的近似誤差(approximation error)會減小,只有與輸入實例較近的(相似的)訓練實例才會對預測結果起作用。但缺點是“學習”的估計誤差(estimation error)會增大,預測結果會對近鄰的實例點非常敏感。
-
如果鄰近的實例點恰巧是噪聲,預測就會出錯。換句話說,k值的減小就意味著整體模型變得復雜,容易發生過擬合。
-
選較大的 k 值,相當于用較大鄰域中的訓練實例進行預測。優點是可以減少學習的估計誤差,但缺點是學習的近似誤差會增大。這時與輸入實例較遠的(不相似的)訓練實例也會對預測起作用,使預測發生錯誤。
-
k值的增大就意味著整體的模型變得簡單。
-
如果 k=N,無論輸入實例是什么,都將簡單地預測它屬于在訓練實例中最多的類。模型過于簡單,完全忽略大量有用信息,不可取。
-
應用中,k 值一般取一個比較小的數值。通常采用交叉驗證法來選取最優的 k 值。
2.4 分類決策規則
- 多數表決(majority voting rule)
假設損失函數為0-1損失,對于 xix_ixi? 的近鄰域 Nk(x)N_k(x)Nk?(x) 的分類是 cjc_jcj?,那么誤分類率是:
1k∑xi∈Nk(x)I(yi≠cj)=1?1k∑xi∈Nk(x)I(yi=cj)\frac{1}{k} \sum\limits_{x_i \in N_k(x) }I(y_i \neq c_j) = 1- \frac{1}{k}\sum\limits_{x_i \in N_k(x) } I(y_i = c_j)k1?xi?∈Nk?(x)∑?I(yi??=cj?)=1?k1?xi?∈Nk?(x)∑?I(yi?=cj?)
要使誤分類率最小,那么就讓 ∑xi∈Nk(x)I(yi=cj)\sum\limits_{x_i \in N_k(x) } I(y_i = c_j)xi?∈Nk?(x)∑?I(yi?=cj?) 最大,所以選多數的那個類(經驗風險最小化)
3. 實現方法, kd樹
-
算法實現時,需要對大量的點進行距離計算,復雜度是 O(n2)O(n^2)O(n2),訓練集很大時,效率低,不可取
-
考慮特殊的結構存儲訓練數據,以減少計算距離次數,如 kdkdkd 樹
3.1 構造 kdkdkd 樹
kdkdkd 樹是一種對 k 維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構。
- kdkdkd 樹是二叉樹,表示對k維空間的一個劃分(partition)。
- 構造 kdkdkd 樹相當于不斷地用垂直于坐標軸的超平面將 k 維空間切分,構成一系列的k維超矩形區域。
- kdkdkd 樹的每個結點對應于一個 k 維超矩形區域。
構造 kdkdkd 樹的方法:
- 根結點:使根結點對應于k維空間中包含所有實例點的超矩形區域;通過遞歸方法,不斷地對 k 維空間進行切分,生成子結點
- 在超矩形區域(結點)上選擇一個坐標軸和在此坐標軸上的一個切分點,確定一個超平面,將當前超矩形區域切分為左右兩個子區域(子結點)
- 實例被分到兩個子區域。這個過程直到子區域內沒有實例時終止(終止時的結點為葉結點)。在此過程中,將實例保存在相應的結點上。
Python 代碼
class KdNode():def __init__(self, dom_elt, split, left, right):self.dom_elt = dom_elt # k維向量節點(k維空間中的一個樣本點)self.split = split # 整數(進行分割維度的序號)self.left = left # 該結點分割超平面左子空間構成的kd-treeself.right = right # 該結點分割超平面右子空間構成的kd-treeclass KdTree():def __init__(self, data):k = len(data[0]) # 實例的向量維度def CreatNode(split, data_set):if not data_set:return Nonedata_set.sort(key=lambda x: x[split])split_pos = len(data_set) // 2 # 整除median = data_set[split_pos]split_next = (split + 1) % kreturn KdNode(median, split,CreatNode(split_next, data_set[:split_pos]),CreatNode(split_next, data_set[split_pos + 1:]))self.root = CreatNode(0, data)def preorder(self, root):if root:print(root.dom_elt)if root.left:self.preorder(root.left)if root.right:self.preorder(root.right) data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]] kd = KdTree(data) kd.preorder(kd.root)運行結果:
[7, 2] [5, 4] [2, 3] [4, 7] [9, 6] [8, 1]3.2 搜索 kdkdkd 樹
給定目標點,搜索其最近鄰。
- 先找到包含目標點的葉結點
- 從該葉結點出發,依次回退到父結點;不斷查找與目標點最鄰近的結點
- 當確定不可能存在更近的結點時終止。
- 這樣搜索就被限制在空間的局部區域上,效率大為提高。
- 目標點的最近鄰一定在以目標點為中心并通過當前最近點的超球體的內部。
- 然后返回當前結點的父結點,如果父結點的另一子結點的超矩形區域與超球體相交,那么在相交的區域內尋找與目標點更近的實例點。
- 如果存在這樣的點,將此點作為新的當前最近點。算法轉到更上一級的父結點,繼續上述過程。
- 如果父結點的另一子結點的超矩形區域與超球體不相交,或不存在比當前最近點更近的點,則停止搜索。
Python 代碼
from collections import namedtuple# 定義一個namedtuple,分別存放最近坐標點、最近距離和訪問過的節點數 result = namedtuple("Result_tuple","nearest_point nearest_dist nodes_visited")def find_nearest(tree, point):k = len(point) # 數據維度def travel(kd_node, target, max_dist):if kd_node is None:return result([0] * k, float("inf"), 0)# python中用float("inf")和float("-inf")表示正負無窮nodes_visited = 1s = kd_node.split # 進行分割的維度pivot = kd_node.dom_elt # 進行分割的“軸”if target[s] <= pivot[s]: # 如果目標點第s維小于分割軸的對應值(目標離左子樹更近)nearer_node = kd_node.left # 下一個訪問節點為左子樹根節點further_node = kd_node.right # 同時記錄下右子樹else: # 目標離右子樹更近nearer_node = kd_node.right # 下一個訪問節點為右子樹根節點further_node = kd_node.lefttemp1 = travel(nearer_node, target, max_dist) # 進行遍歷找到包含目標點的區域nearest = temp1.nearest_point # 以此葉結點作為“當前最近點”dist = temp1.nearest_dist # 更新最近距離nodes_visited += temp1.nodes_visitedif dist < max_dist:max_dist = dist # 最近點將在以目標點為球心,max_dist為半徑的超球體內temp_dist = abs(pivot[s] - target[s]) # 第s維上目標點與分割超平面的距離if max_dist < temp_dist: # 判斷超球體是否與超平面相交return result(nearest, dist, nodes_visited) # 不相交則可以直接返回,不用繼續判斷# ----------------------------------------------------------------------# 計算目標點與分割點的歐氏距離p = np.array(pivot)t = np.array(target)temp_dist = np.linalg.norm(p-t)if temp_dist < dist: # 如果“更近”nearest = pivot # 更新最近點dist = temp_dist # 更新最近距離max_dist = dist # 更新超球體半徑# 檢查另一個子結點對應的區域是否有更近的點temp2 = travel(further_node, target, max_dist)nodes_visited += temp2.nodes_visitedif temp2.nearest_dist < dist: # 如果另一個子結點內存在更近距離nearest = temp2.nearest_point # 更新最近點dist = temp2.nearest_dist # 更新最近距離return result(nearest, dist, nodes_visited)return travel(tree.root, point, float("inf")) # 從根節點開始遞歸 from time import time from random import randomdef random_point(k):return [random() for _ in range(k)]def random_points(k, n):return [random_point(k) for _ in range(n)]ret = find_nearest(kd, [3, 4.5]) print(ret)N = 400000 t0 = time() kd2 = KdTree(random_points(3, N))#40萬個3維點(坐標值0-1之間) ret2 = find_nearest(kd2, [0.1, 0.5, 0.8]) t1 = time() print("time: ", t1 - t0, " s") print(ret2)運行結果:40萬個點,只用了4s就搜索完畢,找到最近鄰點
Result_tuple(nearest_point=[2, 3], nearest_dist=1.8027756377319946, nodes_visited=4) time: 4.314465284347534 s Result_tuple(nearest_point=[0.10186986970329936, 0.5007753108096316, 0.7998708312483109], nearest_dist=0.002028350099282986, nodes_visited=49)4. 鳶尾花KNN分類
4.1 KNN實現
# -*- coding:utf-8 -*- # @Python Version: 3.7 # @Time: 2020/3/2 22:44 # @Author: Michael Ming # @Website: https://michael.blog.csdn.net/ # @File: 3.KNearestNeighbors.py # @Reference: https://github.com/fengdu78/lihang-code import math import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from collections import Counterclass KNearNeighbors():def __init__(self, X_train, y_train, neighbors=3, p=2):self.n = neighborsself.p = pself.X_train = X_trainself.y_train = y_traindef predict(self, X):knn_list = []# 先在訓練集中取n個點出來,計算距離for i in range(self.n):dist = np.linalg.norm(X - self.X_train[i], ord=self.p)knn_list.append((dist, self.y_train[i]))# 再在剩余的訓練集中取出剩余的,計算距離,有距離更近的,替換knn_list里最大的for i in range(self.n, len(self.X_train)):max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))dist = np.linalg.norm(X - self.X_train[i], ord=self.p)if knn_list[max_index][0] > dist:knn_list[max_index] = (dist, self.y_train[i])# 取出所有的n個最近鄰點的標簽knn = [k[-1] for k in knn_list]count_pairs = Counter(knn)# 次數最多的標簽,排序后最后一個 標簽:出現次數max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]return max_countdef score(self, X_test, y_test):right_count = 0for X, y in zip(X_test, y_test): # zip 同時遍歷多個對象label = self.predict(X)if math.isclose(label, y, rel_tol=1e-5): # 浮點型相等判斷right_count += 1print("準確率:%.4f" % (right_count / len(X_test)))return right_count / len(X_test)if __name__ == '__main__':# ---------鳶尾花K近鄰----------------iris = load_iris()df = pd.DataFrame(iris.data, columns=iris.feature_names)df['label'] = iris.targetplt.scatter(df[:50][iris.feature_names[0]], df[:50][iris.feature_names[1]], label=iris.target_names[0])plt.scatter(df[50:100][iris.feature_names[0]], df[50:100][iris.feature_names[1]], label=iris.target_names[1])plt.xlabel(iris.feature_names[0])plt.ylabel(iris.feature_names[1])data = np.array(df.iloc[:100, [0, 1, -1]]) # 取前2種花,前兩個特征X, y = data[:, :-1], data[:, -1]# 切分數據集,留20%做測試數據X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)# KNN算法,近鄰選擇20個,距離度量L2距離clf = KNearNeighbors(X_train, y_train, 20, 2)# 預測測試點,統計正確率clf.score(X_test, y_test)# 隨意給一個點,用KNN預測其分類test_point = [4.75, 2.75]test_point_flower = '測試點' + iris.target_names[int(clf.predict(test_point))]print("測試點的類別是:%s" % test_point_flower)plt.plot(test_point[0], test_point[1], 'bx', label=test_point_flower)plt.rcParams['font.sans-serif'] = 'SimHei' # 消除中文亂碼plt.rcParams['axes.unicode_minus'] = False # 正常顯示負號plt.legend()plt.show() 準確率:1.0000 測試點的類別是:測試點setosa4.2 sklearn KNN
sklearn.neighbors.KNeighborsClassifier
class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None, **kwargs)- n_neighbors: 臨近點個數
- p: 距離度量
- algorithm: 近鄰算法,可選{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}
- weights: 確定近鄰的權重
5. 文章完整代碼
# -*- coding:utf-8 -*- # @Python Version: 3.7 # @Time: 2020/3/2 22:44 # @Author: Michael Ming # @Website: https://michael.blog.csdn.net/ # @File: 3.KNearestNeighbors.py # @Reference: https://github.com/fengdu78/lihang-code import math import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from collections import Counter import timedef L_p(xi, xj, p=2):if len(xi) == len(xj) and len(xi) > 0:sum = 0for i in range(len(xi)):sum += math.pow(abs(xi[i] - xj[i]), p)return math.pow(sum, 1 / p)else:return 0class KNearNeighbors():def __init__(self, X_train, y_train, neighbors=3, p=2):self.n = neighborsself.p = pself.X_train = X_trainself.y_train = y_traindef predict(self, X):knn_list = []# 先在訓練集中取n個點出來,計算距離for i in range(self.n):dist = np.linalg.norm(X - self.X_train[i], ord=self.p)knn_list.append((dist, self.y_train[i]))# 再在剩余的訓練集中取出剩余的,計算距離,有距離更近的,替換knn_list里最大的for i in range(self.n, len(self.X_train)):max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))dist = np.linalg.norm(X - self.X_train[i], ord=self.p)if knn_list[max_index][0] > dist:knn_list[max_index] = (dist, self.y_train[i])# 取出所有的n個最近鄰點的標簽knn = [k[-1] for k in knn_list]count_pairs = Counter(knn)# 次數最多的標簽,排序后最后一個 標簽:出現次數max_count = sorted(count_pairs.items(), key=lambda x: x[1])[-1][0]return max_countdef score(self, X_test, y_test):right_count = 0for X, y in zip(X_test, y_test): # zip 同時遍歷多個對象label = self.predict(X)if math.isclose(label, y, rel_tol=1e-5): # 浮點型相等判斷right_count += 1print("準確率:%.4f" % (right_count / len(X_test)))return right_count / len(X_test)class KdNode():def __init__(self, dom_elt, split, left, right):self.dom_elt = dom_elt # k維向量節點(k維空間中的一個樣本點)self.split = split # 整數(進行分割維度的序號)self.left = left # 該結點分割超平面左子空間構成的kd-treeself.right = right # 該結點分割超平面右子空間構成的kd-treeclass KdTree():def __init__(self, data):k = len(data[0]) # 實例的向量維度def CreatNode(split, data_set):if not data_set:return Nonedata_set.sort(key=lambda x: x[split])split_pos = len(data_set) // 2 # 整除median = data_set[split_pos]split_next = (split + 1) % kreturn KdNode(median, split,CreatNode(split_next, data_set[:split_pos]),CreatNode(split_next, data_set[split_pos + 1:]))self.root = CreatNode(0, data)def preorder(self, root):if root:print(root.dom_elt)if root.left:self.preorder(root.left)if root.right:self.preorder(root.right)from collections import namedtuple# 定義一個namedtuple,分別存放最近坐標點、最近距離和訪問過的節點數 result = namedtuple("Result_tuple","nearest_point nearest_dist nodes_visited")def find_nearest(tree, point):k = len(point) # 數據維度def travel(kd_node, target, max_dist):if kd_node is None:return result([0] * k, float("inf"), 0)# python中用float("inf")和float("-inf")表示正負無窮nodes_visited = 1s = kd_node.split # 進行分割的維度pivot = kd_node.dom_elt # 進行分割的“軸”if target[s] <= pivot[s]: # 如果目標點第s維小于分割軸的對應值(目標離左子樹更近)nearer_node = kd_node.left # 下一個訪問節點為左子樹根節點further_node = kd_node.right # 同時記錄下右子樹else: # 目標離右子樹更近nearer_node = kd_node.right # 下一個訪問節點為右子樹根節點further_node = kd_node.lefttemp1 = travel(nearer_node, target, max_dist) # 進行遍歷找到包含目標點的區域nearest = temp1.nearest_point # 以此葉結點作為“當前最近點”dist = temp1.nearest_dist # 更新最近距離nodes_visited += temp1.nodes_visitedif dist < max_dist:max_dist = dist # 最近點將在以目標點為球心,max_dist為半徑的超球體內temp_dist = abs(pivot[s] - target[s]) # 第s維上目標點與分割超平面的距離if max_dist < temp_dist: # 判斷超球體是否與超平面相交return result(nearest, dist, nodes_visited) # 不相交則可以直接返回,不用繼續判斷# ----------------------------------------------------------------------# 計算目標點與分割點的歐氏距離p = np.array(pivot)t = np.array(target)temp_dist = np.linalg.norm(p - t)if temp_dist < dist: # 如果“更近”nearest = pivot # 更新最近點dist = temp_dist # 更新最近距離max_dist = dist # 更新超球體半徑# 檢查另一個子結點對應的區域是否有更近的點temp2 = travel(further_node, target, max_dist)nodes_visited += temp2.nodes_visitedif temp2.nearest_dist < dist: # 如果另一個子結點內存在更近距離nearest = temp2.nearest_point # 更新最近點dist = temp2.nearest_dist # 更新最近距離return result(nearest, dist, nodes_visited)return travel(tree.root, point, float("inf")) # 從根節點開始遞歸if __name__ == '__main__':# ---------計算距離----------------x1 = [1, 1]x2 = [5, 1]x3 = [4, 4]X = [x1, x2, x3]for i in range(len(X)):for j in range(i + 1, len(X)):for p in range(1, 5):print("x%d,x%d的L%d距離是:%.2f" % (i + 1, j + 1, p, L_p(X[i], X[j], p)))# ---------鳶尾花K近鄰----------------iris = load_iris()df = pd.DataFrame(iris.data, columns=iris.feature_names)df['label'] = iris.targetplt.scatter(df[:50][iris.feature_names[0]], df[:50][iris.feature_names[1]], label=iris.target_names[0])plt.scatter(df[50:100][iris.feature_names[0]], df[50:100][iris.feature_names[1]], label=iris.target_names[1])plt.xlabel(iris.feature_names[0])plt.ylabel(iris.feature_names[1])data = np.array(df.iloc[:100, [0, 1, -1]]) # 取前2種花,前兩個特征X, y = data[:, :-1], data[:, -1]# 切分數據集,留20%做測試數據X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)# KNN算法,近鄰選擇20個,距離度量L2距離clf = KNearNeighbors(X_train, y_train, 20, 2)# 預測測試點,統計正確率clf.score(X_test, y_test)# 隨意給一個點,用KNN預測其分類test_point = [4.75, 2.75]test_point_flower = '測試點' + iris.target_names[int(clf.predict(test_point))]print("測試點的類別是:%s" % test_point_flower)plt.plot(test_point[0], test_point[1], 'bx', label=test_point_flower)plt.rcParams['font.sans-serif'] = 'SimHei' # 消除中文亂碼plt.rcParams['axes.unicode_minus'] = False # 正常顯示負號plt.legend()plt.show()# ---------sklearn KNN----------from sklearn.neighbors import KNeighborsClassifierclf_skl = KNeighborsClassifier(n_neighbors=50, p=4, algorithm='kd_tree')start = time.time()sum = 0for i in range(100):clf_skl.fit(X_train, y_train)sum += clf_skl.score(X_test, y_test)end = time.time()print("平均準確率:%.4f" % (sum / 100))print("花費時間:%0.4f ms" % (1000 * (end - start) / 100))# ------build KD Tree--------------data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]kd = KdTree(data)kd.preorder(kd.root)# ------search in KD Tree-----------from time import timefrom random import randomdef random_point(k):return [random() for _ in range(k)]def random_points(k, n):return [random_point(k) for _ in range(n)]ret = find_nearest(kd, [3, 4.5])print(ret)N = 400000t0 = time()kd2 = KdTree(random_points(3, N))ret2 = find_nearest(kd2, [0.1, 0.5, 0.8])t1 = time()print("time: ", t1 - t0, " s")print(ret2)總結
以上是生活随笔為你收集整理的K 近邻法(K-Nearest Neighbor, K-NN)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 程序员面试金典 - 面试题 17.09.
- 下一篇: LeetCode677. 键值映射(Tr