Fast RCNN 训练自己的数据集(3训练和检测)
Fast RCNN 訓練自己的數據集(3訓練和檢測)
轉載請注明出處,樓燚(yì)航的blog,http://www.cnblogs.com/louyihang-loves-baiyan/
https://github.com/YihangLou/fast-rcnn-train-another-dataset?這是我在github上修改的幾個文件的鏈接,求星星啊,求星星啊(原諒我那么不要臉~~)
在之前兩篇文章中我介紹了怎么編譯Fast RCNN,和怎么修改Fast RCNN的讀取數據接口,接下來我來說明一下怎么來訓練網絡和之后的檢測過程
先給看一下極好的檢測效果
https://github.com/YihangLou/fast-rcnn-train-another-dataset
1.預訓練模型介紹
首先在data目錄下,有兩個目錄就是之前在1中解壓好
- fast_rcnn_models/
- imagenet_models/
fast_rcnn_model文件夾下面是作者用fast rcnn訓練好的三個網絡,分別對應著小、中、大型網絡,大家可以試用一下這幾個網絡,看一些檢測效果,他們訓練都迭代了40000次,數據集都是pascal_voc的數據集。
imagenet_model文件夾下面是在Imagenet上訓練好的通用模型,在這里用來初始化網絡的參數
在這里我比較推薦先用中型網絡訓練,中型網絡訓練和檢測的速度都比較快,效果也都比較理想,大型網絡的話訓練速度比較慢,我當時是5000多個標注信息,網絡配置默認,中型網絡訓練大概兩三個小時,大型網絡的話用十幾個小時,需要注意的是網絡訓練最好用GPU,CPU的話太慢了,我當時用的實驗室的服務器,有16塊Tesla K80,用起來真的是灰常爽!
2. 修改模型文件配置
模型文件在models下面對應的網絡文件夾下,在這里我用中型網絡的配置文件修改為例子
比如:我的檢測目標物是car ,那么我的類別就有兩個類別即 background 和 car
因此,首先打開網絡的模型文件夾,打開train.prototxt
修改的地方重要有三個
分別是個地方
OK,如果你要進一步修改網絡訓練中的學習速率,步長,gamma值,以及輸出模型的名字,需要在同目錄下的solver.prototxt中修改。
如下圖:
3.啟動Fast RCNN網絡訓練
啟動訓練:
./tools/train_net.py --gpu 11 --solver models/VGG_CNN_M_1024_LOUYIHANG/solver.prototxt --weights data/imagenet_models/VGG_CNN_M_1024.v2.caffemodel --imdb KakouTrain
參數講解:
- 這里的--是兩個-,markdown寫的,大家不要輸錯
- train_net.py是網絡的訓練文件,之后的參數都是附帶的輸入參數
- --gpu 代表機器上的GPU編號,如果是nvidia系列的tesla顯卡,可以在終端中輸入nvidia-smi來查看當前的顯卡負荷,選擇合適的顯卡
- --solver 代表模型的配置文件,train.prototxt的文件路徑已經包含在這個文件之中
- --weights 代表初始化的權重文件,這里用的是Imagenet上預訓練好的模型,中型的網絡我們選擇用VGG_CNN_M_1024.v2.caffemodel
- --imdb 這里給出的訓練的數據庫名字需要在factory.py的__sets中,我在文件里面有__sets['KakouTrain'],train_net.py這個文件會調用factory.py再生成kakou這個類,來讀取數據
4.啟動Fast RCNN網絡檢測
我修改了tools下面的demo.py這個文件,用來做檢測,并且將檢測的坐標結果輸出到相應的txt文件中
可以看到原始的demo.py 是用網絡測試了兩張圖像,并做可視化輸出,有具體的檢測效果,但是我是在Linux服務器的終端下,沒有display device,因此部分代碼要少做修改
下面是原始的demo.py:
#!/usr/bin/env python# -------------------------------------------------------- # Fast R-CNN # Copyright (c) 2015 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ross Girshick # --------------------------------------------------------""" Demo script showing detections in sample images.See README.md for installation instructions before running. """import _init_paths from fast_rcnn.config import cfg from fast_rcnn.test import im_detect from utils.cython_nms import nms from utils.timer import Timer import matplotlib.pyplot as plt import numpy as np import scipy.io as sio import caffe, os, sys, cv2 import argparseCLASSES = ('__background__','aeroplane', 'bicycle', 'bird', 'boat','bottle', 'bus', 'car', 'cat', 'chair','cow', 'diningtable', 'dog', 'horse','motorbike', 'person', 'pottedplant','sheep', 'sofa', 'train', 'tvmonitor')NETS = {'vgg16': ('VGG16','vgg16_fast_rcnn_iter_40000.caffemodel'),'vgg_cnn_m_1024': ('VGG_CNN_M_1024','vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'),'caffenet': ('CaffeNet','caffenet_fast_rcnn_iter_40000.caffemodel')}def vis_detections(im, class_name, dets, thresh=0.5):"""Draw detected bounding boxes."""inds = np.where(dets[:, -1] >= thresh)[0]if len(inds) == 0:returnim = im[:, :, (2, 1, 0)]fig, ax = plt.subplots(figsize=(12, 12))ax.imshow(im, aspect='equal')for i in inds:bbox = dets[i, :4]score = dets[i, -1]ax.add_patch(plt.Rectangle((bbox[0], bbox[1]),bbox[2] - bbox[0],bbox[3] - bbox[1], fill=False,edgecolor='red', linewidth=3.5))ax.text(bbox[0], bbox[1] - 2,'{:s} {:.3f}'.format(class_name, score),bbox=dict(facecolor='blue', alpha=0.5),fontsize=14, color='white')ax.set_title(('{} detections with ''p({} | box) >= {:.1f}').format(class_name, class_name,thresh),fontsize=14)plt.axis('off')plt.tight_layout()plt.draw()def demo(net, image_name, classes):"""Detect object classes in an image using pre-computed object proposals."""# Load pre-computed Selected Search object proposalsbox_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo',image_name + '_boxes.mat')obj_proposals = sio.loadmat(box_file)['boxes']# Load the demo imageim_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo', image_name + '.jpg')im = cv2.imread(im_file)# Detect all object classes and regress object boundstimer = Timer()timer.tic()scores, boxes = im_detect(net, im, obj_proposals)timer.toc()print ('Detection took {:.3f}s for ''{:d} object proposals').format(timer.total_time, boxes.shape[0])# Visualize detections for each classCONF_THRESH = 0.8NMS_THRESH = 0.3for cls in classes:cls_ind = CLASSES.index(cls)cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]cls_scores = scores[:, cls_ind]dets = np.hstack((cls_boxes,cls_scores[:, np.newaxis])).astype(np.float32)keep = nms(dets, NMS_THRESH)dets = dets[keep, :]print 'All {} detections with p({} | box) >= {:.1f}'.format(cls, cls,CONF_THRESH)vis_detections(im, cls, dets, thresh=CONF_THRESH)def parse_args():"""Parse input arguments."""parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',default=0, type=int)parser.add_argument('--cpu', dest='cpu_mode',help='Use CPU mode (overrides --gpu)',action='store_true')parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',choices=NETS.keys(), default='vgg16')args = parser.parse_args()return argsif __name__ == '__main__':args = parse_args()prototxt = os.path.join(cfg.ROOT_DIR, 'models', NETS[args.demo_net][0],'test.prototxt')caffemodel = os.path.join(cfg.ROOT_DIR, 'data', 'fast_rcnn_models',NETS[args.demo_net][1])if not os.path.isfile(caffemodel):raise IOError(('{:s} not found.\nDid you run ./data/script/''fetch_fast_rcnn_models.sh?').format(caffemodel))if args.cpu_mode:caffe.set_mode_cpu()else:caffe.set_mode_gpu()caffe.set_device(args.gpu_id)net = caffe.Net(prototxt, caffemodel, caffe.TEST)print '\n\nLoaded network {:s}'.format(caffemodel)print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'print 'Demo for data/demo/000004.jpg'demo(net, '000004', ('car',))print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'print 'Demo for data/demo/001551.jpg'demo(net, '001551', ('sofa', 'tvmonitor'))plt.show()復制這個demo.py 修改成CarFaceTest.py,下面是修改后的文件
修改后的文件主要是添加了outputDetectionResult和runDetection兩個函數, 添加了部分注釋
5.檢測結果
訓練數據集
首先給出我的訓練數據集,其實我的訓練數據集并不是太復雜的
測試數據集
輸出檢測結果到txt文件中,
測試效果
**在復雜場景下的測試效果非常好,速度也非常快,中型網絡監測平均每張在K80顯卡下時0.1~0.2S左右,圖像的尺寸是480*640,6000張測試數據集下達到的準確率是98%!!!**
總結
以上是生活随笔為你收集整理的Fast RCNN 训练自己的数据集(3训练和检测)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Fast RCNN 训练自己数据集 (2
- 下一篇: RCNN (Regions with C