基于线性SVM的CIFAR-10图像集分类
個人網站:紅色石頭的機器學習之路
CSDN博客:紅色石頭的專欄
知乎:紅色石頭
微博:RedstoneWill的微博
GitHub:RedstoneWill的GitHub
微信公眾號:AI有道(ID:redstonewill)
之前我用了六篇文章來詳細介紹了支持向量機SVM的算法理論和模型,鏈接如下:
1. 線性支持向量機LSVM
2. 對偶支持向量機DSVM
3. 核支持向量機KSVM
4. 軟間隔支持向量機
5. 核邏輯回歸KLR
6. 支持向量回歸SVR
實際上,支持向量機SVM確實是機器學習中一個非常重要也是非常復雜的模型。關于SVM的詳細理論和推導,本文不再闡述,讀者可以直接閱讀上面的六篇文章。
學習完了復雜的理論知識,很多朋友可能非常想通過一個實際的例子,動手編寫出一個SVM程序,應用到實際中。那么本文就將帶領大家動手寫出自己的SVM程序,并且應用到圖像的分類問題中。我們將在經典的CIFAR10圖像數據集上進行SVM程序驗證。
話不多說,正式開始!
1. SVM的基本思想
簡單來說,支持向量機SVM就是在特征空間中找到一條最佳的分類超平面,能夠讓正、負樣本距離該超平面的間隔(margin)最大化。
以二維平面為例,確定一條直線對正負樣本進行分類,如下圖所示:
很明顯,雖然分類線H1、H2、H3都能夠將正負樣本完全分開,但是毫無疑問H3更好一些。原因是正負樣本距離H3都足夠遠,即間隔「margin」最大。這就是SVM的基本思想:盡量讓所有樣本距離分類超平面越遠越好。
2. 線性分類與得分函數
在線性分類器算法中,輸入為x,輸出為y,令權重系數為W,常數項系數為b。我們定義得分函數s為:
s=Wx+bs=Wx+b
這是線性分類器的一般形式,得分函數s所屬類別值越大,表示預測該類別的概率越大。
以圖像識別為例,共有3個類別「cat,dog,ship」。令輸入x的特征維度為4「即包含4個像素值」,W的維度是3x4,b的維度是3x1。在W和b確定后,得到各個類別的得分函數s為:
由上圖可知,因為總有3個類別,得分函數s是3x1的向量。其中,cat score=-96.8,dog score=437.9,ship score=61.95。從s的值來說,dog score最高,cat score最低,則預測為狗的概率更大一些。而該圖片真實標簽是一只貓,顯然,從得分函數s上來看,該線性分類器的預測結果是錯誤的。
通常為了簡化計算,我們直接將W和b整合成一個矩陣,同時將x額外增加一個全為1的維度。這樣,得分函數s的表達式得到了簡化:
W:=[W??b]W:=[Wb]
x:=[x;?1]x:=[x;1]
s=Wxs=Wx
示例圖如下:
3. 優化策略與損失函數
通常來說,SVM的優化策略是樣本到分類超平面的距離最大化。也就是說盡量讓正負樣本距離分類超平面有足夠寬的間隔,這是基于距離的衡量優化方式。針對上文提到的例子,圖片真實標簽是一只貓,但是得到的s值卻是最低的,顯然這不是我們希望看到的。最好的情況應該是cat score最高。這樣才能保證預測cat的概率更大。此時,利用SVM的間隔最大化的思想,就要求cat score不僅僅要大于其它類別的s值,而且要達到一定的程度,可以說有個最低閾值。
因此,這種新的SVM優化策略可以這樣理解:正確類別對應的得分函數s應該比其它類別的得分函數s大一個閾值 ΔΔ:
syi≥sj+Δsyi≥sj+Δ
接下來,我們就可以根據這種思想定義SVM的損失函數:
Li=∑j≠yimax(0,sj?syi+Δ)Li=∑j≠yimax(0,sj?syi+Δ)
其中,yiyi表示正確的類別,j表示錯誤類別。從LiLi的表達式可以看出,只有當syisyi比sjsj大超過閾值 ΔΔ 時,LiLi才為零,否則LiLi大于零。這種策略類似于距離最大化策略。
舉個例子來解釋LiLi的計算過程:例如得分函數s=[-1, 5, 4],y1y1是真實樣本,令Δ=3Δ=3,則:
Li=max(0,?1?5+3)+max(0,4?5+3)=0+2=2Li=max(0,?1?5+3)+max(0,4?5+3)=0+2=2
該損失函數由兩部分組成:y1y1與y0y0,y1y1與y2y2。由于y1y1與y0y0的差值大于閾值 ΔΔ,則其損失函數為0;雖然y1y1比y2y2大,但差值小于閾值 ΔΔ,則計算得到其損失函數為2。總的損失函數即為2。
這類損失函數的表達式一般稱作合頁損失函數「Hinge Loss Function」:
顯然,只有當sj?syi+Δ<0sj?syi+Δ<0 時,損失函數才為零。
這種合頁損失函數的優點是體現了SVM距離最大化的思想;而且,損失函數大于零時,是線性函數,便于梯度下降算法求導。
除了這種線性hinge loss SVM之外,還有squared hinge loss SVM,即采用平方的形式:
Li=∑j≠yimax(0,sj?syi+Δ)2Li=∑j≠yimax(0,sj?syi+Δ)2
這種squared hinge loss SVM與linear hinge loss SVM相比較,特點是對違背間隔閾值要求的點加重懲罰,違背的越大,懲罰越大。某些實際應用中,squared hinge loss SVM的效果更好一些。具體使用哪個,可以根據實際問題,進行交叉驗證再確定。
對于超參數閾值 ΔΔ,一般設置 Δ=1Δ=1。因為,權重系數W是可伸縮的,直接影響著得分函數s的大小。所以說,Δ=1Δ=1 或 Δ=10Δ=10,實際上沒有差別,對W的伸縮完全可以抵消掉 ΔΔ 的數值影響。因此,通常把 ΔΔ 設置為1即可。此時的損失函數為:
Li=∑j≠yimax(0,sj?syi+1)Li=∑j≠yimax(0,sj?syi+1)
SVM中,為了防止模型過擬合,可以使用正則化「Regularization」方法。例如使用L2正則化:
R(W)=∑k∑lw2k,lR(W)=∑k∑lwk,l2
引入正則化項之后的損失函數為:
L=1NLi+λR(W)L=1NLi+λR(W)
其中,N是訓練樣本個數,λλ 是正則化參數,可調。一般來說,λλ 越大,對權重W的懲罰越大;λλ 越小,對權重W的懲罰越小。λλ 實際上是權衡損失函數第一項和第二項之間的關系:λλ 越大,對W的懲罰更大,犧牲正負樣本之間的間隔,可能造成欠擬合「underfit」;λλ 越小,得到的正負樣本間隔更大,但是W數值會變大,可能造成過擬合「overfit」。實際應用中,可通過交叉驗證,選擇合適的正則化參數λλ。
常數項b是否需要正則化?其實一般b是否正則化對模型的影響很小。可以對b進行正則化,也可以選擇不。實際應用中,通常只對權重系數W進行正則化。
4. 線性SVM實戰
首先,簡單介紹一下我們將要用到的經典數據集:CIFAR-10。
CIFAR-10數據集由60000張3×32×32的 RGB 彩色圖片構成,共10個分類。50000張訓練,10000張測試(交叉驗證)。這個數據集最大的特點在于將識別遷移到了普適物體,而且應用于多分類,是非常經典和常用的數據集。
這個數據集網上可以下載,我直接給大家下好了,放在云盤里,需要的自行領取。
鏈接:https://pan.baidu.com/s/1iZPwt72j-EpVUbLKgEpYMQ
密碼:vy1e
下面的代碼是隨機選擇每種類別下的5張圖片并顯示:
# Visualize some examples from the dataset. # We show a few examples of training images from each class. classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] num_classes = len(classes) samples_per_class = 7 for y, cls in enumerate(classes):idxs = np.flatnonzero(y_train == y)idxs = np.random.choice(idxs, samples_per_class, replace=False)for i, idx in enumerate(idxs):plt_idx = i * num_classes + y + 1plt.subplot(samples_per_class, num_classes, plt_idx)plt.imshow(X_train[idx].astype('uint8'))plt.axis('off')if i == 0:plt.title(cls) plt.show()接下來,就是對SVM計算hinge loss,包含L2正則化,代碼如下:
scores = X.dot(W) correct_class_score = scores[range(num_train), list(y)].reshape(-1,1) # (N,1) margin = np.maximum(0, scores - correct_class_score + 1) margin[range(num_train), list(y)] = 0 loss = np.sum(margin) / num_train + 0.5 * reg * np.sum(W * W)計算W梯度的代碼如下:
num_classes = W.shape[1] inter_mat = np.zeros((num_train, num_classes)) inter_mat[margin > 0] = 1 inter_mat[range(num_train), list(y)] = 0 inter_mat[range(num_train), list(y)] = -np.sum(inter_mat, axis=1)dW = (X.T).dot(inter_mat) dW = dW/num_train + reg*W根據SGD算法,每次迭代后更新W:
W -= learning_rate * dW訓練過程中,使用交叉驗證的方法選擇最佳的學習因子 learning_rate 和正則化參數 reg,代碼如下:
learning_rates = [1.4e-7, 1.5e-7, 1.6e-7] regularization_strengths = [8000.0, 9000.0, 10000.0, 11000.0, 18000.0, 19000.0, 20000.0, 21000.0]results = {} best_lr = None best_reg = None best_val = -1 # The highest validation accuracy that we have seen so far. best_svm = None # The LinearSVM object that achieved the highest validation rate.for lr in learning_rates:for reg in regularization_strengths:svm = LinearSVM()loss_history = svm.train(X_train, y_train, learning_rate = lr, reg = reg, num_iters = 2000)y_train_pred = svm.predict(X_train)accuracy_train = np.mean(y_train_pred == y_train)y_val_pred = svm.predict(X_val)accuracy_val = np.mean(y_val_pred == y_val)if accuracy_val > best_val:best_lr = lrbest_reg = regbest_val = accuracy_valbest_svm = svmresults[(lr, reg)] = accuracy_train, accuracy_valprint('lr: %e reg: %e train accuracy: %f val accuracy: %f' %(lr, reg, results[(lr, reg)][0], results[(lr, reg)][1])) print('Best validation accuracy during cross-validation:\nlr = %e, reg = %e, best_val = %f' %(best_lr, best_reg, best_val))訓練結束后,選擇最佳的學習因子 learning_rate 和正則化參數 reg,在測試圖片集上進行驗證,代碼如下:
# Evaluate the best svm on test set y_test_pred = best_svm.predict(X_test) test_accuracy = np.mean(y_test == y_test_pred) print('linear SVM on raw pixels final test set accuracy: %f' % test_accuracy)linear SVM on raw pixels final test set accuracy: 0.384000
最后,有個比較好玩的操作,我們可以將訓練好的權重W可視化:
# Visualize the learned weights for each class. # Depending on your choice of learning rate and regularization strength, these may # or may not be nice to look at. w = best_svm.W[:-1,:] # strip out the bias w = w.reshape(32, 32, 3, 10) w_min, w_max = np.min(w), np.max(w) classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] for i in range(10):plt.subplot(2, 5, i + 1)# Rescale the weights to be between 0 and 255wimg = 255.0 * (w[:, :, :, i].squeeze() - w_min) / (w_max - w_min)plt.imshow(wimg.astype('uint8'))plt.axis('off')plt.title(classes[i])可以明顯看出,由W重構的圖片具有所屬樣本類別相似的地方,這正是線性SVM學習到的東西。
5. 總結
本文講述的線性SVM利用距離間隔最大的思想,利用hinge loss的優化策略,來構建一個機器學習模型,并將這個簡單模型應用到CIFAR-10圖片集中進行訓練和測試。實際測試的準確率在40%左右。準確率雖然不是很高,但是此SVM是線性模型,沒有引入核函數構建非線性模型,也沒有使用AlexNet,VGG,GoogLeNet,ResNet等卷積網絡。測試結果比隨機猜測10%要好很多,是一個不錯的可實操的有趣模型。
完整代碼,點擊「源碼」獲取。
源碼
參考資料:
http://cs231n.github.io/linear-classify/
總結
以上是生活随笔為你收集整理的基于线性SVM的CIFAR-10图像集分类的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Python 3.8 稳定版正式发布,新
- 下一篇: 控件的WM_NOTIFY消息映射