什么是k-NN算法?怎样实现?终于有人讲明白了
導讀:使用分類模型預測類標簽。
作者:阿迪蒂亞·夏爾馬(Aditya Sharma)、維什韋什·拉維·什里馬利(Vishwesh Ravi Shrimali)、邁克爾·貝耶勒(Michael Beyeler)
來源:大數據DT(ID:hzdashuju)
以蘭普威爾小鎮為例,那里的人們為他們的兩支球隊——蘭普威爾紅隊和蘭普威爾藍隊——而瘋狂。紅隊已經存在很長時間了,人們很喜歡這支隊伍。
但是后來,一些外地來的富翁買下了紅隊的最佳射手,成立了一支新的球隊——藍隊。令多數紅隊球迷不滿的是,這位最佳射手將繼續帶領藍隊奪得冠軍。多年后,盡管一些球迷對他早期的職業選擇強烈不滿,但他還是回到了紅隊。可是不管怎么說,你會明白為什么紅隊的球迷和藍隊的球迷一直不能和睦相處。
事實上,這兩隊的球迷是如此分裂,以至于他們從未在同一處居住過。我們甚至聽說過這樣的故事:當藍隊球迷搬到隔壁時,紅隊球迷就會故意離開。故事是真實的!
不管怎樣,我們是新到鎮上的,我們正挨家挨戶向人們推銷藍隊產品。然而,我們偶爾會遇到心在滴血的紅隊球迷因為我們推銷藍隊的東西而對我們大吼大叫,還把我們趕出他們的草坪。太不友好了!完全避開這些紅隊球迷,而只拜訪藍隊球迷,這樣壓力會小很多,我們的時間也能更好地被利用。
我們相信可以預測紅隊球迷的生活區,開始記錄我們的活動軌跡。如果我們路過紅隊球迷的家,則會在手邊的城鎮地圖上畫一個三角形;否則會畫一個正方形。一段時間后,我們對每個人的居住地有了一個很好的了解,如圖3-3所示。
▲圖3-3 在地圖中標記紅隊和藍隊球迷居住地
可是,在圖3-3中,我們正在靠近一間標記為綠色圓圈的房子。我們應該敲他們的門嗎?我們試圖找到一些線索,以確定他們可能是哪個隊的球迷(也許在后門廊上掛著隊旗,可我們沒看到)。我們怎樣才能知道敲他們的門是安全的呢?
這個例子恰恰描述了監督學習算法可以解決的問題。我們有一堆觀測數據(房子、位置以及顏色),這些數據構成了我們的訓練數據。我們可以利用這些數據從經驗中學習,當我們要對一個新房子進行顏色預測的任務時,我們就可以做出明智的估計。
正如前面說過的那樣,紅隊球迷對他們的球隊充滿感情,所以他們永遠不會和藍隊球迷住在一起。我們能不能利用這些信息,觀察一下周圍的房子,再看看新房子里住的是哪個隊的球迷?
這正是k-NN算法能夠實現的。
01 理解k-NN算法
k-NN算法可以說是機器學習算法中最簡單的一個。原因是我們基本上只需要存儲訓練數據集。然后,要預測一個新的數據點,我們只需要找到訓練數據集中最近的數據點:它的最近鄰居。
簡而言之,k-NN算法認為一個數據點可能與其鄰居屬于同一類。想想看,如果我們的鄰居是紅隊球迷,我們可能也是紅隊球迷;否則,我們早就搬走了。對于藍隊球迷來說也是如此。
當然,有些鄰居可能稍微有點復雜。在這種情況下,我們可能不只要考慮我們的最近鄰居(k=1),而且還要考慮離我們最近的k個最近鄰居。讓我們繼續前面介紹過的例子,如果我們是紅隊球迷,我們不可能搬到大多數人都認為可能是藍隊球迷的社區。
這就是它的全部。
02 用OpenCV實現k-NN
使用OpenCV,通過cv2.ml.KNearest_Create()函數我們可以很容易創建一個k-NN模型。構建模型包括下列步驟:
生成一些訓練數據。
對于一個給定的數k,創建一個k-NN對象。
為我們要分類的一個新數據點找到k個最近鄰。
根據多數票分配新數據點的類標簽。
繪制結果。
首先,我們導入所有必要的模塊:OpenCV的k-NN算法模塊、NumPy的數據處理模塊、Matplotlib的繪圖模塊。如果你正在使用Jupyter Notebook,請不要忘記調用%matplotlib inline魔術命令:
import?numpy?as?np import?cv2import?matplotlib.pyplot?as?plt %matplotlib?inlineplt.style.use('ggplot')1. 生成訓練數據
第一步是生成一些訓練數據。為此,我們將使用NumPy的隨機數生成器。我們將固定隨機數生成器的種子,這樣重新運行腳本總是可以生成相同的值:
np.random.seed(42)好了,現在讓我們開始吧。我們的訓練數據應該是什么樣子的呢?
在前面的例子中,每個數據點都是城鎮地圖上的一個房子。每個數據點都有兩個特征(即數據點在城鎮地圖上的位置坐標x和y)以及一個類標簽(即藍隊球迷居住地是一個藍色方塊,紅隊球迷居住地是一個紅色三角形)。
因此,單個數據點的特征在城鎮地圖上可以用x和y坐標的一個二元向量來表示。類似地,如果是一個藍色方塊,那么標簽是0;如果是一個紅色三角形,那么標簽是1。這個過程包括數據點生成、數據點繪制以及新數據點的標簽預測。讓我們來看看如何實現這些步驟:
1)隨機選擇地圖上的位置以及一個隨機標簽(0或者1),我們可以生成單個數據點。假設城鎮地圖的范圍是0≤x≤100和0≤y≤100。那么,我們可以生成一個隨機數據點,如下所示:
single_data_point?=?np.random.randint(0,?100,?2) single_data_pointOut:
array([51,?92])在上述輸出中我們可以看到,這將在0到100之間選擇兩個隨機整數。我們把第一個整數解釋為地圖上數據點的x坐標,第二個整數解釋為數據點的y坐標。
2)類似地,我們為數據點選擇一個標簽:
single_label?=?np.random.randint(0,?2) single_labelOut:
0這個數據點的類是0,將其解釋為一個藍色方塊。
3)讓我們將這個過程封裝到一個函數中,該函數以生成的數據點數(即num_samples)和每個數據點的特征數(即num_features)作為輸入:
def?generate_data(num_samples,?num_features=2):"""Randomly?generates?a?number?of?data?points"""因為在我們的例子中,特征數是2,所以可以使用這個數作為默認的參數值。這樣,如果我們在調用函數時,沒有顯式地指定num_features,那么會將一個為2的值自動分配該函數。我相信你現在已經明白了。
我們要創建的數據矩陣應該有num_samples行num_features列,而且矩陣中的每個元素都應該是從(0, 100)范圍內隨機選取的一個整數:
data_size?=?(num_samples,?num_features)train_data?=?np.random.randint(0,?100,?size=data_size)類似地,我們要創建一個向量,包含(0, 2)范圍內的一個隨機整數標簽,對于所有樣本:
labels_size?=?(num_samples,?1)labels?=?np.random.randint(0,?2,?size=labels_size)不要忘記讓函數返回生成的數據:
return?train_data.astype(np.float32),?labels提示:在涉及數據類型時,OpenCV可能有點挑剔,因此一定要將數據點轉換成np.float32!
4)讓我們對該函數進行測試并生成任意數量的數據點,假設為11個數據點,其坐標是隨機選擇的:
train_data,?labels?=?generate_data(11) train_dataOut:
array([[71.,?60.],[20.,?82.],[86.,?74.],[74.,?87.],[99.,?23.],[?2.,?21.],[52.,??1.],[87.,?29.],[37.,??1.],[63.,?59.],[20.,?32.]],?dtype=float32)5)正如我們在上述輸出中看到的那樣,train_data變量是一個11×2的數組,每一行對應一個數據點。通過在數組中建立索引來查看第一個數據點及其對應的標簽:
train_data[0],?labels[0]Out:
(array([71.,?60.],?dtype=float32),?array([1]))6)這就告訴我們第一個數據點是一個紅色三角形(因為它的類是1),在城鎮地圖上的位置是(x, y)=(71, 60)。如果需要,我們可以使用Matplotlib繪制城鎮地圖上的這個數據點:
plt.plot(train_data[0,?0],?train_data[0,?1],?color='r',?marker='^',?markersize=10) plt.xlabel('x?coordinate') plt.ylabel('y?coordinate')我們得到的結果如圖3-4所示。
▲圖3-4 生成第一個數據點及其標簽
7)但是,如果我們想一次看到整個訓練集呢?讓我們為此編寫一個函數。應該把所有藍色方塊數據點的列表(all_blue)以及所有紅色三角形數據點的列表(all_red)作為函數的輸入:
def?plot_data(all_blue,?all_red):8)我們的函數應該把所有的藍色數據點繪制成藍色方塊(使用顏色“b”和標記“s”),這可以使用matplotlib的scatter函數來實現。為了使其可以工作,我們必須以一個N×2的數組形式傳遞藍色數據點,其中N是樣本數。然后,all_blue [:, 0]包含數據點的所有x坐標,all_blue[:, 1]包含數據點的所有y坐標:
plt.figure(figsize=(10,?6))plt.scatter(all_blue[:,?0],?all_blue[:,?1],?c='b',?marker='s',?s=180)9)類似地,所有的紅色數據點也可以這樣實現:
plt.scatter(all_red[:,?0],?all_red[:,?1],?c='r',?marker='^',?s=180)10)最后,我們用標簽標注圖:
plt.xlabel('x?coordinate?(feature?1)')plt.ylabel('y?coordinate?(feature?2)')Out:
array([False,?False,?False,??True,?False,??True,??True,??True,??True,True,?False])11)讓我們在數據集上試試看!首先,我們必須將所有的數據點拆分成紅色數據集和藍色數據集。使用下列命令,我們可以快速選擇前面創建的label數組中所有等于0的元素(ravel平展數組):
labels.ravel()?==?012)所有藍色數據點是之前創建的train_data數組的所有行,對應的標簽是0:
blue?=?train_data[labels.ravel()?==?0]13)對于所有的紅色數據點也可以這樣實現:
red?=?train_data[labels.ravel()?==?1]14)最后,讓我們繪制所有的數據點:
plot_data(blue,?red)創建的圖如圖3-5所示。
▲圖3-5 生成所有數據點
現在是時候訓練分類器了。
2. 訓練分類器
與機器學習的所有其他函數一樣,k-NN分類器是OpenCV 3.1 ml模塊的一部分。使用下列命令,我們可以創建一個新的分類器:
knn?=?cv2.ml.KNearest_create()提示:在OpenCV的老版本中,這個函數被稱為cv2.KNearest( )。
然后,我們將訓練數據傳遞給train方法:
knn.train(train_data,?cv2.ml.ROW_SAMPLE,?labels)Out:
True此處,我們必須讓knn知道我們的數據是一個N×2的數組(即每一行是一個數據點)。成功后,函數返回True。
3. 預測一個新數據點的標簽
knn提供的另一個非常有用的方法是findNearest。該方法可以基于其最近鄰居預測一個新數據點的標簽。
generate_data函數生成一個新的數據點實際上是很容易的!我們可以把一個新數據點看成大小為1的數據集:
newcomer,?_?=?generate_data(1) newcomerOut:
array([[91.,?59.]],?dtype=float32)我們的函數還會返回一個隨機標簽,可是我們對此并不感興趣。我們想用已訓練的分類器來預測!我們可以讓Python忽略一個帶有下劃線(_)的輸出值。
讓我們再來看看我們的城鎮地圖。我們將像前面那樣繪制訓練集,而且還將新數據點添加為一個綠色圓圈(因為我們還不知道這個數據點應該是藍色方塊還是紅色三角形):
plot_data(blue,?red) plt.plot(newcomer[0,?0],?newcomer[0,?1],?'go',?markersize=14);提示:你可以向plt.plot函數調用添加一個分號來抑制其輸出,與Matlab中的一樣。
上述代碼將生成圖3-6(–環)。
▲圖3-6 生成的結果圖
如果你必須根據該數據點的鄰居來猜測的話,你會為新數據點分配什么標簽?藍色方塊,還是紅色三角形?
這要看情況,不是嗎?如果我們查看離該點最近的房屋(大概在(x, y)=(85, 75),在圖3-6中的虛線圓圈內),我們可能也會給新數據點分配一個三角形。這也正好是我們的分類器所預測的k=1:
ret,?results,?neighbor,?dist?=?knn.findNearest(newcomer,?1) print("Predicted?label:\t",?results) print("Neighbor's?label:\t",?neighbor) print("Distance?to?neighbor:\t",?dist)Out:
Predicted?label:?????[[1.]] Neighbor's?label:?????[[1.]] Distance?to?neighbor:?????[[250.]]這里,knn報告最近鄰居是250個任意單位距離,這個鄰居標簽是1(我們說過它對應于紅色三角形),因此,新數據點也應該標記為1。如果我們看看k=2的最近鄰居和k=3的最近鄰居,情況也是一樣的。但我們要注意不要令k為偶數,這是為什么呢?在圖3-6中(虛線圓圈)可以看到原因,在虛線圓圈內的6個最近鄰居中,有3個藍色方塊,3個紅色三角形—打平了!
提示:在平局情況下,OpenCV的k-NN實現將更喜歡與數據點的總體距離更近的鄰居。
最后,如果我們擴大搜索窗口,根據k=7的最近鄰居對新數據點進行分類,結果會怎樣呢(圖3-6中的實線圓圈)?
我們通過調用findNearest方法、k=7的鄰居找出答案:
ret,?results,?neighbor,?dist?=?knn.findNearest(newcomer,?7) print("Predicted?label:\t",?results) print("Neighbor's?label:\t",?neighbor) print("Distance?to?neighbor:\t",?dist)Out:
Predicted?label:?????[[0.]] Neighbor's?label:?????[[1.?1.?0.?0.?0.?1.?0.]] Distance?to?neighbor:?????[[?250.??401.??784.??916.?1073.?1360.?4885.]]此時,預測標簽變成了0(藍色方塊)。原因是,現在我們在實線圓圈內有4個鄰居是藍色方塊(標簽0),只有3個鄰居是紅色三角形(標簽1)。因此,多數票表明這個新數據點也應該是一個藍色方塊。
或者,可以使用predict方法進行預測。但是,首先我們需要設置k:
knn.setDefaultK(1) knn.predict(newcomer)Out:
(1.0,?array([[1.]],?dtype=float32))如果我們設置k=7會怎樣呢?讓我們來看看吧:
knn.setDefaultK(7) knn.predict(newcomer)Out:
(0.0,?array([[0.]],?dtype=float32))正如你所看到的,k-NN的結果隨k值的變化而變化。但是,通常我們事先并不知道k取什么值最合適。對于這個問題,最簡單的解決方案是嘗試一系列k值,看看哪個值表現最佳。
關于作者:阿迪蒂亞·夏爾馬(Aditya Sharma),羅伯特·博世(Robert Bosch)公司的一名高級工程師,致力于解決真實世界的自動計算機視覺問題。曾獲得羅伯特·博世公司2019年人工智能編程馬拉松的首名。
維什韋什·拉維·什里馬利(Vishwesh Ravi Shrimali),于2018年畢業于彼拉尼博拉理工學院(BITS Pilani)機械工程專業。此后一直在BigVision LLC從事深度學習和計算機視覺方面的工作,還參與了官方OpenCV課程的創建。
邁克爾·貝耶勒(Michael Beyeler),是華盛頓大學神經工程和數據科學的博士后研究員,致力于仿生視覺的計算模型研究,以為盲人植入人工視網膜(仿生眼睛),改善盲人的感知體驗。他的工作屬于神經科學、計算機工程、計算機視覺和機器學習的交叉領域。
本文摘編自《機器學習:使用OpenCV、Python和scikit-learn進行智能圖像處理(原書第2版)》(ISBN:978-7-111-66826-8),經出版方授權發布。
延伸閱讀《機器學習》(原書第2版)
點擊上圖了解及購買
轉載請聯系微信:DoctorData
推薦語:一本基于OpenCV4和Python的機器學習實戰手冊,既詳細介紹機器學習及OpenCV相關的基礎知識,又通過具體實例展示如何使用OpenCV和Python實現各種機器學習算法,并提供大量示例代碼,可以幫助你掌握機器學習實用技巧,解決各種不同的機器學習和圖像處理問題。
劃重點👇
干貨直達👇
詳解數據管理發展的5個階段
解讀OpenShift的邏輯架構和技術架構
元宇宙時代,技術長什么樣
這5種思維模式,大牛產品經理都在用
更多精彩👇
在公眾號對話框輸入以下關鍵詞
查看更多優質內容!
讀書?|?書單?|?干貨?|?講明白?|?神操作?|?手把手
大數據?|?云計算?|?數據庫?|?Python?|?爬蟲?|?可視化
AI?|?人工智能?|?機器學習?|?深度學習?|?NLP
5G?|?中臺?|?用戶畫像?|?數學?|?算法?|?數字孿生
據統計,99%的大咖都關注了這個公眾號
👇
總結
以上是生活随笔為你收集整理的什么是k-NN算法?怎样实现?终于有人讲明白了的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: WF动态挂单(1)
- 下一篇: JavaEE实战班第九天