【图像分类】 基于Pytorch的多类别图像分类实战
歡迎大家來到圖像分類專欄,本篇基于Pytorch完成一個多類別圖像分類實戰(zhàn)。
作者?| 郭冰洋
編輯 | 言有三
1 簡介
實現一個完整的圖像分類任務,大致需要分為五個步驟:
1、選擇開源框架
目前常用的深度學習框架主要包括tensorflow、caffe、pytorch、mxnet等;
2、構建并讀取數據集
根據任務需求搜集相關圖像搭建相應的數據集,常見的方式包括:網絡爬蟲、實地拍攝、公共數據使用等。隨后根據所選開源框架讀取數據集。
3、框架搭建
選擇合適的網絡模型、損失函數以及優(yōu)化方式,以完成整體框架的搭建
4、訓練并調試參數
通過訓練選定合適超參數
5、測試準確率
在測試集上驗證模型的最終性能
本文利用Pytorch框架,按照上述結構實現一個基本的圖像分類任務,并詳細闡述其中的細節(jié)及注意事項。
2 數據集
本次實戰(zhàn)選擇的數據集為Kaggle競賽中的細胞數據集,共包含9961個訓練樣本,2491個測試樣本,可以分為嗜曙紅細胞、淋巴細胞、單核細胞、中性白細胞4個類別,圖片大小為320x240。
Pytorch中封裝了相應的數據讀取的類函數,通過調用torch.utils.data.Datasets函數,則可以實現讀取功能。
__init__()模塊用來定義相關的參數,__len__()模塊用來獲取訓練樣本個數,__getitem__()模塊則用來獲取每張具體的圖片,在讀取圖片時其可以通過opencv庫、PIL庫等進行讀取,具體代碼如下:
# 數據集
class dataset(data.Dataset):
? ?# 參數預定義
此外,需要定義圖像增強模塊,即上述代碼中的transform,通常采取的操作為翻轉、剪切等,關于圖像增強的具體介紹可以參考公眾號前作。
【技術綜述】深度學習中的數據增強方法都有哪些?
需要特別強調的是對圖像進行去均值處理,很多同學不明白為何要減去均值,其主要的原因是圖像作為一種平穩(wěn)的數據分布,通過減去數據對應維度的統(tǒng)計平均值,可以消除公共部分,以凸顯個體之間的特征和差異。進行去均值前后操作后的圖像對比如下:
3 框架搭建
本次實戰(zhàn)主要選取了VGG16、Resnet50、InceptionV4三個經典網絡,也是對前篇文章的一個總結。
損失函數則選擇交叉熵損失函數:【技術綜述】一文道盡softmax loss及其變種
優(yōu)化方式選擇SGD、Adam優(yōu)化兩種:【模型訓練】SGD的那些變種,真的比SGD強嗎
完整代碼獲取方式:發(fā)送關鍵詞“多類別分類”給公眾號
4 訓練及參數調試
初始學習率設置為0.01,batch size設置為8,衰減率設置為0.00001,迭代周期為15,在不同框架組合下的最佳準確率和最低loss如下圖所示:
可以發(fā)現在驗證集上Resnet-50+SGD+Cross Entropy的組合下取得了99%左右的準確率,相反VGG-16結果則稍微差一些。
最佳組合下的準確率走勢曲線如下圖所示:
5 測試
對上述模型分別在測試集上進行測試,所獲得的結果如下圖所示,整體精度比訓練集上約下降了一個百分點:
關于代碼,可以參考有三AI開源的12大深度學習開源框架使用的項目:
【完結】給新手的12大深度學習開源框架快速入門項目
總結
以上就是整個多類別圖像分類實戰(zhàn)的過程,由于時間限制,本次實戰(zhàn)并沒有對多個數據集進行訓練,因此沒有列出同一模型在不同數據集上的表現。
有三AI夏季劃
有三AI夏季劃進行中,歡迎了解并加入,系統(tǒng)性成長為中級CV算法工程師。
轉載文章請后臺聯(lián)系
侵權必究
往期精選
【技術綜述】你真的了解圖像分類嗎?
【技術綜述】多標簽圖像分類綜述
【圖像分類】分類專欄正式上線啦!初入CV、AI你需要一份指南針!
總結
以上是生活随笔為你收集整理的【图像分类】 基于Pytorch的多类别图像分类实战的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【图像分割应用】设备自动化(一)——自动
- 下一篇: 【图像分类】从数据集和经典网络开始