yolov5训练_YoloV5模型训练实战教程:Kaggle全球小麦检测竞赛
寫(xiě)在前面
前段時(shí)間參加了Kaggle的一個(gè)目標(biāo)檢測(cè)競(jìng)賽,比賽后期因?yàn)楣ぷ鬏^繁忙就擱置了,但仍然獲得了銅牌(前10%)。因此在這里想跟大家分享下自己的方案,希望能幫助大家更好的了解目標(biāo)檢測(cè)這一經(jīng)典的計(jì)算機(jī)視覺(jué)領(lǐng)域。
這篇教程的主要代碼來(lái)源于這個(gè)git倉(cāng)庫(kù)(https://github.com/ultralytics/yolov5),是國(guó)外一個(gè)公司開(kāi)源的。選擇這個(gè)項(xiàng)目一是因?yàn)樾阅芎?#xff0c;最新mAP達(dá)到了50.8/25.5ms,太強(qiáng)大了;二是因?yàn)樵擁?xiàng)目是用pytorch實(shí)現(xiàn)的,使用門(mén)檻低,很適合初學(xué)者。下面開(kāi)始我們的實(shí)戰(zhàn)教程。
YoloV5
數(shù)據(jù)分析
這里需要重點(diǎn)說(shuō)明下,CV任務(wù)的第一步絕對(duì)不是搭模型,而是觀察數(shù)據(jù),只有了解的數(shù)據(jù)的組成和分布,才能搭出性能更好好的模型。首先看下比賽數(shù)據(jù),看下面9張圖片,可以看出小麥品種不一,風(fēng)格差異很大,所以很明顯Domain Gap是這個(gè)比賽的難點(diǎn)。
小麥數(shù)據(jù)
再來(lái)看下標(biāo)注框,每張圖有幾十個(gè)目標(biāo),分布非常密集,所以這個(gè)任務(wù)其實(shí)屬于密集小目標(biāo)檢測(cè)問(wèn)題,因此像FPN這種金字塔模型肯定是必不可少的。
數(shù)據(jù)標(biāo)注展示
數(shù)據(jù)處理
目標(biāo)檢測(cè)任務(wù)跟分類(lèi)不同,數(shù)據(jù)的格式有很多種,比較常用的是COCO和VOC格式,YoloV5使用的是YOLO自己的格式。Kaggle的數(shù)據(jù)標(biāo)簽是用csv格式保存的,需要轉(zhuǎn)換成YOLO的標(biāo)注格式。格式轉(zhuǎn)換可以參考下面這段代碼。
import numpy as np # linear algebraimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)import osdf = pd.read_csv('../input/global-wheat-detection/train.csv')bboxs = np.stack(df['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))for i, column in enumerate(['x', 'y', 'w', 'h']): df[column] = bboxs[:,i]df.drop(columns=['bbox'], inplace=True)df['x_center'] = df['x'] + df['w']/2df['y_center'] = df['y'] + df['h']/2df['classes'] = 0from tqdm.auto import tqdmimport shutil as shdf = df[['image_id','x', 'y', 'w', 'h','x_center','y_center','classes']]index = list(set(df.image_id))index = list(set(df.image_id))source = 'train'if True: for fold in [0]: val_index = index[len(index)*fold//5:len(index)*(fold+1)//5] for name,mini in tqdm(df.groupby('image_id')): if name in val_index: path2save = 'val2017/' else: path2save = 'train2017/' if not os.path.exists('convertor/fold{}/labels/'.format(fold)+path2save): os.makedirs('convertor/fold{}/labels/'.format(fold)+path2save) with open('convertor/fold{}/labels/'.format(fold)+path2save+name+".txt", 'w+') as f: row = mini[['classes','x_center','y_center','w','h']].astype(float).values row = row/1024 row = row.astype(str) for j in range(len(row)): text = ' '.join(row[j]) f.write(text) f.write("") if not os.path.exists('convertor/fold{}/images/{}'.format(fold,path2save)): os.makedirs('convertor/fold{}/images/{}'.format(fold,path2save)) sh.copy("../input/global-wheat-detection/{}/{}.jpg".format(source,name),'convertor/fold{}/images/{}/{}.jpg'.format(fold,path2save,name))模型訓(xùn)練
先看下YoloV5長(zhǎng)啥樣,畢竟從YoloV1發(fā)展到V5,模型也確實(shí)復(fù)雜了很多,不過(guò)其實(shí)核心結(jié)構(gòu)沒(méi)有變,只是增加了大量的trick,每一個(gè)trick都需要大量的調(diào)參,這也是YoloV5作者的一個(gè)非常大的貢獻(xiàn)。不過(guò)我建議剛上手不用了解的這么深入,先跑起來(lái)再說(shuō)。
YoloV5網(wǎng)絡(luò)結(jié)構(gòu)圖
數(shù)據(jù)處理完畢后,下載YoloV5的源代碼,在data文件下配置自己的數(shù)據(jù)路徑,例如下面的wheat0.yaml,主要是訓(xùn)練和驗(yàn)證集的路徑,類(lèi)別數(shù)量和類(lèi)別名這4個(gè)參數(shù)。
# train and val datasets (image directory or *.txt file with image paths)train: ./convertor/fold0/images/train2017/val: ./convertor/fold0/images/val2017/# number of classesnc: 1# class namesnames: ['wheat']配置文件更改完成后就可以直接訓(xùn)練了,需要訓(xùn)練100個(gè)epoch才能得到一個(gè)較好的模型。使用如下的訓(xùn)練腳本開(kāi)始訓(xùn)練。
python train.py --img 1024 --batch 2 --epochs 100 --data ../input/configyolo5/wheat0.yaml --cfg ../input/configyolo5/yolov5x.yaml --name yolov5x_fold0下面是我訓(xùn)練模型的可視化結(jié)果,可以看出來(lái)效果還是不錯(cuò)的,基本沒(méi)有漏標(biāo)或誤判。
預(yù)測(cè)結(jié)果展示
寫(xiě)在后面
以上就是YoloV5的實(shí)戰(zhàn)訓(xùn)練教程,其實(shí)跑起來(lái)還是很簡(jiǎn)單的。大家可以先試下,把訓(xùn)練流程跑通,下篇文章我會(huì)把我的測(cè)試流程和kaggle的提交代碼也分享出來(lái),歡迎大家關(guān)注和轉(zhuǎn)發(fā)。有任何問(wèn)題可以在文章下面評(píng)論,我會(huì)及時(shí)回復(fù)。
總結(jié)
以上是生活随笔為你收集整理的yolov5训练_YoloV5模型训练实战教程:Kaggle全球小麦检测竞赛的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: group by rollup 变量名为
- 下一篇: 0基础怎么学python10010基础怎