faster-rcnn code reading: the overall training flow
2. Training
Returning to line 160 of train.py, training is started by calling the sw.train_model method:
def train_model(self, max_iters):
    """Network training loop."""
    last_snapshot_iter = -1
    timer = Timer()
    model_paths = []
    while self.solver.iter < max_iters:
        # Make one SGD update
        timer.tic()
        self.solver.step(1)
        timer.toc()
        if self.solver.iter % (10 * self.solver_param.display) == 0:
            print 'speed: {:.3f}s / iter'.format(timer.average_time)

        if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
            last_snapshot_iter = self.solver.iter
            model_paths.append(self.snapshot())

    if last_snapshot_iter != self.solver.iter:
        model_paths.append(self.snapshot())
    return model_paths
The call self.solver.step(1) in this method runs one forward pass and one backward pass of the network. In the forward pass, data flows from the first layer to the last, where the loss is computed; the gradients of the loss with respect to each layer's inputs are then propagated from the last layer back to the first. The sections below walk through the faster-rcnn pipeline layer by layer.
2.1 input-data layer
The first layer is implemented in Python code; its prototxt definition is:
layer {
  name: 'input-data'
  type: 'Python'
  top: 'data'
  top: 'im_info'
  top: 'gt_boxes'
  python_param {
    module: 'roi_data_layer.layer'
    layer: 'RoIDataLayer'
    param_str: "'num_classes': 2"
  }
}
As the prototxt shows, the input-data layer has three tops (data, im_info, gt_boxes) and is implemented by the RoIDataLayer class. Its preprocessing resizes each image isotropically so that the short side becomes 600; if the long side would then exceed 1000, the image is instead rescaled by the long side so that the long side is 1000, with the short side scaled proportionally. The three outputs are:

1. data: shape (1, 3, h, w) (a batch supports only a single image)
2. im_info: a 3-vector [h, w, scale], where scale = target_size / im_origin_size (the resize factor)
3. gt_boxes: one row (x1, y1, x2, y2, cls) per ground-truth box

The preprocessing is implemented across the functions _get_next_minibatch, get_minibatch, _get_image_blob, prep_im_for_blob, and im_list_to_blob.
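The resize rule described above can be sketched as a small helper. The function name compute_scale is hypothetical; the logic is distilled from what prep_im_for_blob does rather than quoted from it:

```python
def compute_scale(h, w, target_size=600, max_size=1000):
    """Scale factor for isotropic resizing: scale the short side to
    target_size, but cap the long side at max_size (a sketch of the
    prep_im_for_blob rule, not the library code itself)."""
    im_size_min = min(h, w)
    im_size_max = max(h, w)
    scale = float(target_size) / im_size_min
    # If the long side would exceed max_size, rescale by the long side instead.
    if round(scale * im_size_max) > max_size:
        scale = float(max_size) / im_size_max
    return scale

print(compute_scale(480, 640))   # short side 480 -> 600, long side 800 <= 1000
print(compute_scale(500, 1200))  # 600/500 would push the long side past 1000
```

For a 480x640 image the short-side rule applies (scale 1.25); for a 500x1200 image the long-side cap takes over and the scale drops to 1000/1200.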
While the network is being constructed (i.e. at self.solver = caffe.SGDSolver(solver_prototxt)), the setup method of this class is called:
# Relevant config defaults:
__C.TRAIN.IMS_PER_BATCH = 1
__C.TRAIN.SCALES = [600]
__C.TRAIN.MAX_SIZE = 1000
__C.TRAIN.HAS_RPN = True
__C.TRAIN.BBOX_REG = True

def setup(self, bottom, top):
    """Setup the RoIDataLayer."""

    # parse the layer parameter string, which must be valid YAML
    layer_params = yaml.load(self.param_str_)

    self._num_classes = layer_params['num_classes']

    self._name_to_top_map = {}

    # data blob: holds a batch of N images, each with 3 channels
    idx = 0
    top[idx].reshape(cfg.TRAIN.IMS_PER_BATCH, 3,
                     max(cfg.TRAIN.SCALES), cfg.TRAIN.MAX_SIZE)
    self._name_to_top_map['data'] = idx
    idx += 1

    if cfg.TRAIN.HAS_RPN:
        top[idx].reshape(1, 3)
        self._name_to_top_map['im_info'] = idx
        idx += 1

        top[idx].reshape(1, 4)
        self._name_to_top_map['gt_boxes'] = idx
        idx += 1
    else:  # not using RPN
        # rois blob: holds R regions of interest, each is a 5-tuple
        # (n, x1, y1, x2, y2) specifying an image batch index n and a
        # rectangle (x1, y1, x2, y2)
        top[idx].reshape(1, 5)
        self._name_to_top_map['rois'] = idx
        idx += 1

        # labels blob: R categorical labels in [0, ..., K] for K foreground
        # classes plus background
        top[idx].reshape(1)
        self._name_to_top_map['labels'] = idx
        idx += 1

        if cfg.TRAIN.BBOX_REG:
            # bbox_targets blob: R bounding-box regression targets with 4
            # targets per class
            top[idx].reshape(1, self._num_classes * 4)
            self._name_to_top_map['bbox_targets'] = idx
            idx += 1

            # bbox_inside_weights blob: At most 4 targets per roi are active;
            # this binary vector specifies the subset of active targets
            top[idx].reshape(1, self._num_classes * 4)
            self._name_to_top_map['bbox_inside_weights'] = idx
            idx += 1

            top[idx].reshape(1, self._num_classes * 4)
            self._name_to_top_map['bbox_outside_weights'] = idx
            idx += 1

    print 'RoiDataLayer: name_to_top:', self._name_to_top_map
    assert len(top) == len(self._name_to_top_map)
This setup mainly declares the shapes of the tops. Note, however, that every forward pass reshapes the tops again to the actual minibatch blob shapes, which generally differ from the shapes declared here.
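That reshape-on-forward pattern can be illustrated with a minimal sketch. FakeTop and forward_copy are hypothetical stand-ins that mirror what RoIDataLayer.forward does with _name_to_top_map; real Caffe tops behave analogously:

```python
import numpy as np

class FakeTop(object):
    """Stand-in for a Caffe top blob (hypothetical, for illustration only)."""
    def reshape(self, *shape):
        self.data = np.zeros(shape, dtype=np.float32)

def forward_copy(top, name_to_top_map, blobs):
    # Mirror of the layer's forward pass: reshape each top to the actual
    # minibatch blob shape (overriding the shape declared in setup()),
    # then copy the data in.
    for blob_name, blob in blobs.items():
        top_ind = name_to_top_map[blob_name]
        top[top_ind].reshape(*blob.shape)
        top[top_ind].data[...] = blob.astype(np.float32, copy=False)

# setup() declared gt_boxes as (1, 4); this minibatch has 3 ground-truth boxes
tops = [FakeTop(), FakeTop(), FakeTop()]
name_to_top_map = {'data': 0, 'im_info': 1, 'gt_boxes': 2}
blobs = {
    'data': np.zeros((1, 3, 600, 800)),
    'im_info': np.array([[600., 800., 1.25]]),
    'gt_boxes': np.zeros((3, 5)),  # one (x1, y1, x2, y2, cls) row per box
}
forward_copy(tops, name_to_top_map, blobs)
print(tops[2].data.shape)  # → (3, 5), not the (1, 4) declared in setup
```

This is why the shapes set in setup only serve as placeholders: the data layer is free to emit a different number of boxes (or a different image size) on every iteration.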