Fast-RCNN解析:训练阶段代码导读
轉(zhuǎn)載自:http://blog.csdn.net/linj_m/article/details/48930179#0-tsina-1-35514-397232819ff9a47a7b7e80a40613cfe1
關(guān)于Fast-RCNN的解析,我們將主要分為兩個(gè)部分來介紹,其中一個(gè)是訓(xùn)練部分,這個(gè)部分非常重要,是我們需要重點(diǎn)講解的;另一個(gè)是測(cè)試部分,這個(gè)部分關(guān)系到具體的應(yīng)用,所以也是必須要了解的。本篇博文中,我們先從訓(xùn)練部分講起。
訓(xùn)練階段流程
在官方文檔中,訓(xùn)練階段的啟動(dòng)腳本如下所示:
./tools/train_net.py --gpu 0 --solver models/VGG16/solver.prototxt \--weights data/imagenet_models/VGG16.v2.caffemodel從這段腳本中,我們可以知道,訓(xùn)練的入口函數(shù)就在train_net.py中,其位于fast-rcnn/tools/文件夾內(nèi),我們先來看看這個(gè)文件。
if __name__ == '__main__':args = parse_args()print('Called with args:')print(args)if args.cfg_file is not None:cfg_from_file(args.cfg_file)if args.set_cfgs is not None:cfg_from_list(args.set_cfgs)print('Using config:')pprint.pprint(cfg)if not args.randomize:# fix the random seeds (numpy and caffe) for reproducibilitynp.random.seed(cfg.RNG_SEED)caffe.set_random_seed(cfg.RNG_SEED)# set up caffecaffe.set_mode_gpu()if args.gpu_id is not None:caffe.set_device(args.gpu_id)imdb = get_imdb(args.imdb_name)print 'Loaded dataset `{:s}` for training'.format(imdb.name)roidb = get_training_roidb(imdb)output_dir = get_output_dir(imdb, None)print 'Output will be saved to `{:s}`'.format(output_dir)train_net(args.solver, roidb, output_dir,pretrained_model=args.pretrained_model,max_iters=args.max_iters)從以上的code,我們可以看到,train_net.py的主要處理過程包括以下三個(gè)部分:
(1) 首先對(duì)啟動(dòng)腳本的輸入?yún)?shù)進(jìn)行處理,是通過如下這個(gè)函數(shù)parse_args()進(jìn)行處理的。
def parse_args():"""Parse input arguments"""parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')parser.add_argument('--gpu', dest='gpu_id',help='GPU device id to use [0]', default=0, type=int)parser.add_argument('--solver', dest='solver',help='solver prototxt', default=None, type=str)parser.add_argument('--iters', dest='max_iters',help='number of iterations to train',default=40000, type=int)parser.add_argument('--weights', dest='pretrained_model',help='initialize with pretrained model weights', default=None, type=str)parser.add_argument('--cfg', dest='cfg_file',help='optional config file',default=None, type=str)parser.add_argument('--imdb', dest='imdb_name',help='dataset to train on',default='voc_2007_trainval', type=str)parser.add_argument('--rand', dest='randomize',help='randomize (do not use a fixed seed)',action='store_true')parser.add_argument('--set', dest='set_cfgs',help='set config keys', default=None,nargs=argparse.REMAINDER)if len(sys.argv) == 1:parser.print_help()sys.exit(1)args = parser.parse_args()return args從這個(gè)函數(shù)中,我們可以了解到,訓(xùn)練腳本的可選輸入?yún)?shù)包括:
- –gpu: 這個(gè)參數(shù)指定訓(xùn)練使用的GPU設(shè)備,我的電腦只有一枚GPU,默認(rèn)情況下自動(dòng)開啟,其gpu_id為0;
- –solver: 這個(gè)參數(shù)指定網(wǎng)絡(luò)的優(yōu)化方法,并在其solver的prototxt指向了定義網(wǎng)絡(luò)結(jié)構(gòu)的文件(train.prototxt);
- –weights: 這個(gè)參數(shù)指定了finetune的初始參數(shù),我的電腦GPU不怎么高端,只能使用caffenet進(jìn)行finetune;
- –imdb: 這個(gè)參數(shù)指定了訓(xùn)練所需要的訓(xùn)練數(shù)據(jù),如果你需要訓(xùn)練自己的數(shù)據(jù),那么這個(gè)參數(shù)是必須要指定的;
(2) 然后是根據(jù)輸入的參數(shù)(–imdb 參數(shù)后面指定的數(shù)據(jù))來準(zhǔn)備訓(xùn)練樣本,這個(gè)步驟涉及到兩個(gè)函數(shù):一個(gè) imdb=get_imdb(args.imdb_name) , 另一個(gè)是roidb=get_training_roidb(imdb)。關(guān)于這兩個(gè)函數(shù)我們下部分會(huì)花大時(shí)間來解析,這里先不談。
(3) 最后就是訓(xùn)練函數(shù):train_net(args.solver,roidb, output_dir, pretrained_model= args.pretrained_model, max_iters= args.max_iters)
而這個(gè) train_net() 函數(shù)是從 fast_rcnn/lib/fast_rcnn 文件夾中的 train.py 中 import 進(jìn)來的。那么接下來,我們來看看這個(gè)train.py
這個(gè)函數(shù)主要由一個(gè)類SolverWrapper和兩個(gè)函數(shù)get_training_roidb()和train_net()組成。
首先,我們來看看train_net()函數(shù):
可以發(fā)現(xiàn),該函數(shù)是通過調(diào)用類SolverWrapper來實(shí)現(xiàn)其主要功能的,因此,我們跟進(jìn)到類SolverWrapper的類構(gòu)造函數(shù)中去:
def __init__(self, solver_prototxt, roidb, output_dir,pretrained_model=None):"""Initialize the SolverWrapper."""self.output_dir = output_dirprint 'Computing bounding-box regression targets...'self.bbox_means, self.bbox_stds = \rdl_roidb.add_bbox_regression_targets(roidb)print 'done'self.solver = caffe.SGDSolver(solver_prototxt)if pretrained_model is not None:print ('Loading pretrained model ''weights from {:s}').format(pretrained_model)self.solver.net.copy_from(pretrained_model)self.solver_param = caffe_pb2.SolverParameter()with open(solver_prototxt, 'rt') as f:pb2.text_format.Merge(f.read(), self.solver_param)self.solver.net.layers[0].set_roidb(roidb)初始化完成后,就是要調(diào)用train_model函數(shù)來進(jìn)行網(wǎng)絡(luò)訓(xùn)練,我們來看一下它的主體部分:
def train_model(self, max_iters):"""Network training loop."""last_snapshot_iter = -1timer = Timer()while self.solver.iter < max_iters:# Make one SGD updatetimer.tic()self.solver.step(1)timer.toc()if self.solver.iter % (10 * self.solver_param.display) == 0:print 'speed: {:.3f}s / iter'.format(timer.average_time)if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:last_snapshot_iter = self.solver.iterself.snapshot()if last_snapshot_iter != self.solver.iter:self.snapshot()到此為止,網(wǎng)絡(luò)就可以開始訓(xùn)練了。
訓(xùn)練數(shù)據(jù)處理
不過,關(guān)于Fast-RCNN的重頭戲我們其實(shí)還沒開始——那就是如何準(zhǔn)備訓(xùn)練數(shù)據(jù)。
在上面介紹訓(xùn)練的流程中,與此相關(guān)的函數(shù)是:imdb= get_imdb(args.imdb_name)
這個(gè)函數(shù)是從從lib/datasets/文件夾中的factory.py中import進(jìn)來的,我們來看一下這個(gè)函數(shù):
def get_imdb(name):"""Get an imdb (image database) by name."""if not __sets.has_key(name):raise KeyError('Unknown dataset: {}'.format(name))return __sets[name]()這個(gè)函數(shù)很簡(jiǎn)單,其實(shí)就是根據(jù)字典的key來取得訓(xùn)練數(shù)據(jù)。
那么這個(gè)字典是怎么形成的呢?看下面:
它本質(zhì)上是通過lib/datasets/文件夾下面的inria.py引入的。
所以,現(xiàn)在我們就得開始進(jìn)入inria.py(這個(gè)函數(shù)需要我們自己編寫,可以參考pascal_voc.py編寫)。
首先,我們來看看類inria的構(gòu)造函數(shù):
def __init__(self, image_set, devkit_path):datasets.imdb.__init__(self, image_set)self._image_set = image_setself._devkit_path = devkit_pathself._data_path = os.path.join(self._devkit_path, 'data')self._classes = ('__background__', # always index 0'1001')self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))self._image_ext = ['.jpg', '.png']self._image_index = self._load_image_set_index()# Default to roidb handlerself._roidb_handler = self.selective_search_roidb# Specific config optionsself.config = {'cleanup' : True,'use_salt' : True,'top_k' : 2000}assert os.path.exists(self._devkit_path), \'Devkit path does not exist: {}'.format(self._devkit_path)assert os.path.exists(self._data_path), \'Path does not exist: {}'.format(self._data_path)這里面最要注意的是要根據(jù)自己訓(xùn)練的類別同步修改self._classes,我這里面只有兩類。
類 inria 構(gòu)造完成后,會(huì)調(diào)用函數(shù) roidb,這個(gè)函數(shù)是從類 imdb 中繼承過來的,這個(gè)函數(shù)會(huì)調(diào)用 _roidb_handler 來處理,其中 _roidb_handler=self.selective_search_roidb,下面我們來看看這個(gè)函數(shù):
def selective_search_roidb(self):"""Return the database of selective search regions of interest.Ground-truth ROIs are also included.This function loads/saves from/to a cache file to speed up future calls."""cache_file = os.path.join(self.cache_path,self.name + '_selective_search_roidb.pkl')if os.path.exists(cache_file):with open(cache_file, 'rb') as fid:roidb = cPickle.load(fid)print '{} ss roidb loaded from {}'.format(self.name, cache_file)return roidbif self._image_set != 'test':gt_roidb = self.gt_roidb()ss_roidb = self._load_selective_search_roidb(gt_roidb)roidb = datasets.imdb.merge_roidbs(gt_roidb, ss_roidb)else:roidb = self._load_selective_search_roidb(None)print len(roidb)with open(cache_file, 'wb') as fid:cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)print 'wrote ss roidb to {}'.format(cache_file)return roidb這個(gè)函數(shù)在訓(xùn)練階段會(huì)首先調(diào)用get_roidb() 函數(shù):
def gt_roidb(self):"""Return the database of ground-truth regions of interest.This function loads/saves from/to a cache file to speed up future calls."""cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')if os.path.exists(cache_file):with open(cache_file, 'rb') as fid:roidb = cPickle.load(fid)print '{} gt roidb loaded from {}'.format(self.name, cache_file)return roidbgt_roidb = [self._load_inria_annotation(index)for index in self.image_index]with open(cache_file, 'wb') as fid:cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)print 'wrote gt roidb to {}'.format(cache_file)return gt_roidb如果存在cache_file,那么get_roidb()就會(huì)直接從cache_file中讀取信息;如果不存在cache_file,那么會(huì)調(diào)用_load_inria_annotation()來取得標(biāo)注信息。_load_inria_annotation函數(shù)如下所示:
def _load_inria_annotation(self, index):"""Load image and bounding boxes info from txt files of INRIA Person."""filename = os.path.join(self._data_path, 'Annotations', index + '.xml')print 'Loading: {}'.format(filename)def get_data_from_tag(node, tag):return node.getElementsByTagName(tag)[0].childNodes[0].datawith open(filename) as f:data = minidom.parseString(f.read())objs = data.getElementsByTagName('object')num_objs = len(objs)boxes = np.zeros((num_objs, 4), dtype=np.uint16)gt_classes = np.zeros((num_objs), dtype=np.int32)overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)# Load object bounding boxes into a data frame.for ix, obj in enumerate(objs):# Make pixel indexes 0-basedx1 = float(get_data_from_tag(obj, 'xmin')) - 1y1 = float(get_data_from_tag(obj, 'ymin')) - 1x2 = float(get_data_from_tag(obj, 'xmax')) - 1y2 = float(get_data_from_tag(obj, 'ymax')) - 1# ---------------------------------------------# add these lines to avoid the accertion errorif x1 < 0:x1 = 0if y1 < 0:y1 = 0# ----------------------------------------------cls = self._class_to_ind[str(get_data_from_tag(obj, "name")).lower().strip()]boxes[ix, :] = [x1, y1, x2, y2]gt_classes[ix] = clsoverlaps[ix, cls] = 1.0overlaps = scipy.sparse.csr_matrix(overlaps)return {'boxes' : boxes,'gt_classes': gt_classes,'gt_overlaps' : overlaps,'flipped' : False}當(dāng)處理完標(biāo)注的數(shù)據(jù)后,接下來就要載入SS階段獲得的數(shù)據(jù),通過如下函數(shù)完成:
def _load_selective_search_roidb(self, gt_roidb):filename = os.path.abspath(os.path.join(self._devkit_path,self.name + '.mat'))assert os.path.exists(filename), \'Selective search data not found at: {}'.format(filename)raw_data = sio.loadmat(filename)['boxes'].ravel()box_list = []for i in xrange(raw_data.shape[0]):#這個(gè)地方需要注意,如果在SS中你已經(jīng)變換了box的值,那么就不需要再改變box值的位置了#box_list.append(raw_data[i][:, (1, 0, 3, 2)] - 1)box_list.append(raw_data[i][:, (1, 0, 3, 2)])return self.create_roidb_from_box_list(box_list, gt_roidb)有一點(diǎn)需要注意的是,ss中獲得的box的值,和fast-rcnn中認(rèn)為的box值有點(diǎn)差別,那就是你需要交換box的x和y坐標(biāo)。
總結(jié)
以上是生活随笔為你收集整理的Fast-RCNN解析:训练阶段代码导读的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: ubuntu运行Faster R-CNN
- 下一篇: RCNN SPP-net Fast-