DBNet详解
文章目錄
- 創(chuàng)新點
- 算法的整體架構(gòu)
- 自適應(yīng)閾值(Adaptive threshhold)
- 二值化
- 標(biāo)準(zhǔn)二值化
- 可微二值(differentiable Binarization)
- 直觀展示
- 可形變卷積(Deformable convolution)
- 標(biāo)簽的生成
- PSENet標(biāo)簽生成
- DBNet標(biāo)簽生成
- 損失函數(shù)
- 后處理
- 代碼閱讀
- 數(shù)據(jù)預(yù)處理
- 入口
- AugmentDetectionData(數(shù)據(jù)增強(qiáng)類)
- RandomCropData(數(shù)據(jù)裁剪類)
- MakeICDARData(數(shù)據(jù)重新組織類)
- MakeSegDetectionData(生成概率圖和對應(yīng)mask類)
- MakeBorderMap(生成閾值圖和對應(yīng)Mask類)
- NormalizeImage
- FilterKeys
- 模型結(jié)構(gòu)
- 骨干網(wǎng)絡(luò)和FPN
- head部分(decoder)
- binary
- thresh
- step_function
- 損失函數(shù)
- binary loss
- thresh loss
- thresh_binary loss
- 邏輯推理
- 補(bǔ)充
- 語義分割中的loss function
- cross entropy loss
- weighted loss
- focal loss
- dice soft loss
- Dice系數(shù)計算
- Dice loss
- 梯度分析
- 總結(jié)
- soft IOU loss
- 總結(jié)
- 總結(jié)
- soft IOU loss
- 總結(jié)
創(chuàng)新點
? 本文的最大創(chuàng)新點。在基于分割的文本檢測網(wǎng)絡(luò)中,最終的二值化map都是使用的固定閾值來獲取,并且閾值不同對性能影響較大。本文中,對每一個像素點進(jìn)行自適應(yīng)二值化,二值化閾值由網(wǎng)絡(luò)學(xué)習(xí)得到,徹底將二值化這一步驟加入到網(wǎng)絡(luò)里一起訓(xùn)練,這樣最終的輸出圖對于閾值就會非常魯棒。
和常規(guī)基于語義分割算法的區(qū)別是多了一條threshold map分支,該分支的主要目的是和分割圖聯(lián)合得到更接近二值化的二值圖,屬于輔助分支。其余操作就沒啥了。整個核心知識就這些了。
算法的整體架構(gòu)
- 首先,圖像輸入特征提取主干,提取特征;
- 其次,特征金字塔上采樣到相同的尺寸,并進(jìn)行特征級聯(lián)得到特征F;
- 然后,特征F用于預(yù)測概率圖(probability map P)和閾值圖(threshold map T)
- 最后,通過P和F計算近似二值圖(approximate binary map B)
在訓(xùn)練期間對P,T,B進(jìn)行監(jiān)督訓(xùn)練,P和B是用的相同的監(jiān)督信號(label)。在推理時,只需要P或B就可以得到文本框。
網(wǎng)絡(luò)輸出:
1.probability map, w*h*1 , 代表像素點是文本的概率
2.threshhold map, w*h*1, 每個像素點的閾值
3.binary map, w*h*1, 由1,2計算得到,計算公式為DB公式
自適應(yīng)閾值(Adaptive threshhold)
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-V5RaeceH-1610966579215)(C:\F\notebook\DB\20200922201346491.png)]
文中指出傳統(tǒng)的文本檢測算法主要是圖中藍(lán)色線,處理流程如下:
- 首先,通過設(shè)置一個固定閾值將分割網(wǎng)絡(luò)訓(xùn)練得到的概率圖(segmentation map)轉(zhuǎn)化為二值圖(binarization map);
- 然后,使用一些啟發(fā)式技術(shù)(例如像素聚類)將像素分組為文本實例。
而DBNet使用紅色線,思路:
通過網(wǎng)絡(luò)去預(yù)測圖片每個位置處的閾值,而不是采用一個固定的值,這樣就可以很好將背景與前景分離出來,但是這樣的操作會給訓(xùn)練帶來梯度不可微的情況,對此對于二值化提出了一個叫做Differentiable Binarization來解決不可微的問題。
? 閾值圖(threshhold map)使用流程如圖2所示,使用閾值map和不使用閾值map的效果對比如圖6所示,從圖6?中可以看到,即使沒用帶監(jiān)督的閾值map,閾值map也會突出顯示文本邊界區(qū)域,這說明邊界型閾值map對最終結(jié)果是有利的。所以,本文在閾值map上選擇監(jiān)督訓(xùn)練,已達(dá)到更好的表現(xiàn)
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-4gDERWPU-1610966579221)(C:\F\notebook\DB\20200922201612829.png)]
二值化
標(biāo)準(zhǔn)二值化
? 一般使用分割網(wǎng)絡(luò)(segmentation network)產(chǎn)生的概率圖(probability map P),將P轉(zhuǎn)化為一個二值圖P,當(dāng)像素為1的時候,認(rèn)定其為有效的文本區(qū)域,同時二值處理過程:
i和j代表了坐標(biāo)點的坐標(biāo),t是預(yù)定義的閾值;
可微二值(differentiable Binarization)
公式1是不可微的,所以沒法直接用于訓(xùn)練,本文提出可微的二值化函數(shù),如下(其實就是一個帶系數(shù)的sigmoid):
就是近似二值圖;T代表從網(wǎng)絡(luò)中學(xué)習(xí)到的自適應(yīng)閾值圖;k是膨脹因子(經(jīng)驗性設(shè)置k=50).
? 這個近似的二值化函數(shù)的表現(xiàn)類似于標(biāo)準(zhǔn)的二值化函數(shù),如圖4所表示,但是因為可微,所以可以直接用于網(wǎng)絡(luò)訓(xùn)練,基于自適應(yīng)閾值可微二值化不僅可以幫助區(qū)分文本區(qū)域和背景,而且可以將連接緊密的文本實例分離出來。
? 為了說明DB模塊的引入對于聯(lián)合訓(xùn)練的優(yōu)勢,作者對該函數(shù)進(jìn)行梯度分析,也就是對approximate
binary map進(jìn)行求導(dǎo)分析,由于是sigmod輸出,故假設(shè)Loss是bce,對于label為0或者1的位置,其Loss函數(shù)可以重寫為:
x表示probability map-threshold map,最后一層關(guān)于x的梯度很容易計算:
? 看上圖右邊,(b)圖是當(dāng)label=1,x預(yù)測值從-1到1的梯度,可以發(fā)現(xiàn),當(dāng)k=50時候梯度遠(yuǎn)遠(yuǎn)大于k=1,錯誤的區(qū)域梯度更大,對于label=0的情況分析也是一樣的。故:
(1) 通過增加參數(shù)K,就可以達(dá)到增大梯度的目的,加快收斂
(2) 在預(yù)測錯誤位置,梯度也是顯著增加
總之通過引入DB模塊,通過參數(shù)K可以達(dá)到增加梯度幅值,更加有利優(yōu)化,可以使得三個輸出圖優(yōu)化更好,最終分割結(jié)果會優(yōu)異。而DB模塊本身就是帶參數(shù)的sigmod函數(shù),實現(xiàn)如下:
直觀展示
p可以理解,就是有文字的區(qū)域有值0.9以上,沒有文字區(qū)域黑的為0 .
T是一個只有文字邊界才有值的,其他地方為0 .
? 分別是原圖,gt圖,threshold map圖。 這里再說下threshold map圖,非文字邊界處都是灰色的,這是因為統(tǒng)一加了0.3,所有最小值是0.3.
這里其實還看不清,我們把src+gt+threshold map看看。
可以看到:
- p的ground truth是標(biāo)注縮水之后
- T的ground truth是文字塊邊緣分別向內(nèi)向外收縮和擴(kuò)張
- p與T是公式里面的那兩個變量。
再看這個公式與曲線圖:
P和T我們就用ground truth帶入來理解:
? P網(wǎng)絡(luò)學(xué)的文字塊內(nèi)部, T網(wǎng)絡(luò)學(xué)的文字邊緣,兩者計算得到B。 B的ground truth也是標(biāo)注縮水之后,和p用的同一個。 在實際操作中,作者把除了文字塊邊緣的區(qū)域置為0.3.應(yīng)該就是為了當(dāng)在非文字區(qū)域, P=0,T=0.3,x=p-T<0這樣拉到負(fù)半軸更有利于區(qū)分。
可形變卷積(Deformable convolution)
? 可變形卷積可以提供模型一個靈活的感受野,這對于不同縱橫比的文本很有利,本文應(yīng)用可變形卷積,使用3×3卷積核在ResNet-18或者ResNet-50的conv3,conv4,conv5層。
標(biāo)簽的生成
概率圖的標(biāo)簽產(chǎn)成法類似PSENet
PSENet標(biāo)簽生成
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-8eRCYRHv-1610966579234)(C:\F\notebook\DB\20200923193744225.png)]
? 網(wǎng)絡(luò)輸出多個分割結(jié)果(S1,Sn),因此訓(xùn)練時需要有多個GY與其匹配,在本文中,通過收縮原始標(biāo)簽就可以簡單高效的生成不同尺度的GT,如圖5所示,(b)代表原始的標(biāo)注結(jié)果,也表示最大的分割標(biāo)簽mask,即Sn,利用Vatti裁剪算法獲取其他尺度的Mask,如圖5(a),將原始多邊形pn 縮小di 像素到 pi ,收縮后的pi 轉(zhuǎn)換成0/1的二值mask作為GT,用G1,G2,,,,Gn分別代表不同尺度的GT,用數(shù)學(xué)方式表示的話,尺度比例為ri 。
di 的計算方式為:
di=Area(Pn)?(1?ri2)/Perimeter(pn)d_i=Area(P_n)*(1-r_i^2)/Perimeter(p_n) di?=Area(Pn?)?(1?ri2?)/Perimeter(pn?)
Area(·) 是計算多邊形面積的函數(shù), Perimeter(·)是計算多邊形周長的函數(shù),生成Gi時的尺度比例ri計算公式為:
ri=1?(1?m)?(n?i)/(n?1)r_i=1-(1-m)*(n-i)/(n-1) ri?=1?(1?m)?(n?i)/(n?1)
m代表最小的尺度比例,取值范圍是(0,1],使用上式,通過m和n兩個超參數(shù)可以計算出r1,r2,…rn,他們隨著m變現(xiàn)線性增加到最大值1.
DBNet標(biāo)簽生成
給定一張圖片,文本區(qū)域標(biāo)注的多邊形可以描述為:
G={Sk}k=1nG=\{S_k\}_{k=1}^{n} G={Sk?}k=1n?
n是每隔文本框的標(biāo)注點總數(shù),在不同數(shù)據(jù)中可能不同,然后使用vatti裁剪算法,將正樣例區(qū)域產(chǎn)生通過收縮polygon從G到Gs,補(bǔ)償公式計算
D:offset;L:周長;A:面積;r:收縮比例,設(shè)置為0.4;
損失函數(shù)
損失函數(shù)為概率map的loss、二值map的loss和閾值map的loss之和。
Ls 是概率map的loss,Lb 是二值map的loss,均使用二值交叉熵loss(BCE),為了解決正負(fù)樣本不均衡問題,使用hard negative mining, α和β分別設(shè)置為1.0和10 .
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-udDAMvqX-1610966579237)(C:\F\notebook\DB\2020092220283134.png)]
Sl 設(shè)計樣本集,其中正陽樣本和負(fù)樣本比例是1:3
Lt計算方式為擴(kuò)展文本多邊形Gd內(nèi)預(yù)測結(jié)果和標(biāo)簽之間的L1距離之和:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-pGt5zzIG-1610966579239)(C:\F\notebook\DB\20200922203558285.png)]
Rd是在膨脹Gd內(nèi)像素的索引,y*是閾值map的標(biāo)簽
后處理
(由于threshold map的存在,probability map的邊界可以學(xué)習(xí)的很好,因此可以直接按照收縮的方式(Vatti clipping algorithm)擴(kuò)張回去 )
在推理時可以采用概率圖或近似二值圖來生成文本框,為了方便作者選擇了概率圖,具體步驟如下:
1、使用固定閾值0.2將概率圖做二值化得到二值化圖;
2、由二值化圖得到收縮文字區(qū)域;
3、將收縮文字區(qū)域按Vatti clipping算法的偏移系數(shù)D’通過膨脹再擴(kuò)展回來。
D‘就是擴(kuò)展補(bǔ)償,A’是收縮多邊形的面積,L‘就是收縮多邊形的周長,r’作者設(shè)置的是1.5;
(注意r‘的值在DBNet工程中不是1.5,而在我自己的數(shù)據(jù)集上,參數(shù)設(shè)置為1.3較合適,大家訓(xùn)練的時候可以根據(jù)自己模型效果進(jìn)行調(diào)整)
文中說明DB算法的主要優(yōu)勢有以下4點:
- 在五個基準(zhǔn)數(shù)據(jù)集上有良好的表現(xiàn),其中包括水平、多個方向、彎曲的文本。
- 比之前的方法要快很多,因為DB可以提供健壯的二值化圖,從而大大簡化了后處理過程。
- 使用輕量級的backbone(ResNet18)也有很好的表現(xiàn)。
- DB模塊在推理過程中可以去除,因此不占用額外的內(nèi)存和時間的消耗。
參考:
論文鏈接:https://arxiv.org/pdf/1911.08947.pdf
工程鏈接:https://github.com/MhLiao/DB
? https://github.com/WenmuZhou/DBNet.pytorch
- https://blog.csdn.net/qq_22764813/article/details/107785388
- https://blog.csdn.net/qq_39707285/article/details/108739010
- https://zhuanlan.zhihu.com/p/94677957
- https://mp.weixin.qq.com/s/ehbROyE-grp_F3T3YBX9CA
代碼閱讀
數(shù)據(jù)預(yù)處理
入口
在data/image_dataset.py,數(shù)據(jù)預(yù)處理邏輯非常簡單,就是讀取圖片和gt標(biāo)注,解析出每張圖片poly標(biāo)注,包括多邊形標(biāo)注、字符內(nèi)容以及是否是忽略文本,忽略文本一般是比較模糊和小的文本。
具體可以在getitem方法里面插入:
ImageDataset.__getitem__():data_process(data)預(yù)處理配置:
processes:- class: AugmentDetectionDataaugmenter_args:- ['Fliplr', 0.5]- {'cls': 'Affine', 'rotate': [-10, 10]}- ['Resize', [0.5, 3.0]]only_resize: Falsekeep_ratio: False- class: RandomCropDatasize: [640, 640]max_tries: 10- class: MakeICDARData- class: MakeSegDetectionData- class: MakeBorderMap- class: NormalizeImage- class: FilterKeyssuperfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training']預(yù)處理流程:
AugmentDetectionData(數(shù)據(jù)增強(qiáng)類)
DB/data/processes/augment_data.py
? 其目的就是對圖片和poly標(biāo)注進(jìn)行數(shù)據(jù)增強(qiáng),包括翻轉(zhuǎn)、旋轉(zhuǎn)和縮放三個,參數(shù)如配置所示。本文采用的增強(qiáng)庫是imgaug。可以看出本文訓(xùn)練階段對數(shù)據(jù)是不保存比例的resize,然后再進(jìn)行三種增強(qiáng)。
由于icdar數(shù)據(jù),文本區(qū)域占比都是非常小的,故不能用直接resize到指定輸入大小的數(shù)據(jù)增強(qiáng)操作,而是使用后續(xù)的randcrop操作比較科學(xué)。但是如果自己項目的數(shù)據(jù)文本區(qū)域比較大,則可能沒必要采用RandomCropData這么復(fù)雜的數(shù)據(jù)增強(qiáng)操作,直接resize算了。
RandomCropData(數(shù)據(jù)裁剪類)
DB/data/processes/random_crop_data.py
因為數(shù)據(jù)裁剪涉及到比較復(fù)雜的多變形標(biāo)注后處理,所以單獨(dú)列出來 。
? 其目的是對圖片進(jìn)行裁剪到指定的[640, 640]。由于斜框的特點,裁剪增強(qiáng)沒那么容易做,本文采用的裁剪策略非常簡單: 遍歷每一個多邊形標(biāo)注,只要裁剪后有至少有一個poly還在裁剪框內(nèi),則認(rèn)為該次裁剪有效。這個策略主要可以保證一張圖片中至少有一個gt,且實現(xiàn)比較簡單。
其具體流程是:
代碼如下:
def crop_area(self, im, text_polys):h, w = im.shape[:2]h_array = np.zeros(h, dtype=np.int32)w_array = np.zeros(w, dtype=np.int32)#將poly數(shù)據(jù)進(jìn)行水平和垂直方向投影,有標(biāo)注的地方是1,其余地方是0for points in text_polys:points = np.round(points, decimals=0).astype(np.int32)minx = np.min(points[:, 0])maxx = np.max(points[:, 0])w_array[minx:maxx] = 1miny = np.min(points[:, 1])maxy = np.max(points[:, 1])h_array[miny:maxy] = 1# ensure the cropped area not across a text#找出沒有標(biāo)注的水平和垂直坐標(biāo)h_axis = np.where(h_array == 0)[0]w_axis = np.where(w_array == 0)[0]#如果所有位置都有標(biāo)注,則無法裁剪,直接原圖返回if len(h_axis) == 0 or len(w_axis) == 0:return 0, 0, w, h#對水平和垂直坐標(biāo)進(jìn)行連續(xù)區(qū)域分離,其實就是把所有連續(xù)0坐標(biāo)區(qū)域切割處理#后面進(jìn)行隨機(jī)裁剪都是在每個連續(xù)區(qū)域進(jìn)行,可以最大程度保證不會裁斷標(biāo)注h_regions = self.split_regions(h_axis)w_regions = self.split_regions(w_axis)for i in range(self.max_tries):if len(w_regions) > 1:#先從n個區(qū)域隨機(jī)選擇2個區(qū)域,然后在兩個區(qū)域內(nèi)部隨機(jī)選擇兩個點,構(gòu)成x方向最大最小坐標(biāo)xmin, xmax = self.region_wise_random_select(w_regions, w)else:xmin, xmax = self.random_select(w_axis, w)if len(h_regions) > 1:#h方向也是一樣處理ymin, ymax = self.region_wise_random_select(h_regions, h)else:ymin, ymax = self.random_select(h_axis, h)#不能裁剪的過小if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:# area too smallcontinuenum_poly_in_rect = 0for poly in text_polys:#如果有一個poly標(biāo)注沒有出界,則直接返回,表示裁剪成功if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):num_poly_in_rect += 1breakif num_poly_in_rect > 0:return xmin, ymin, xmax - xmin, ymax - yminreturn 0, 0, w, h? 在得到裁剪區(qū)域后,就比較簡單了。先對裁剪區(qū)域圖片進(jìn)行保存長寬比的resize,最長邊為網(wǎng)絡(luò)輸入,例如640x640, 然后從上到下pad,得到640x640的圖片
# 計算crop區(qū)域 crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys) # crop 圖片 保持比例填充 scale_w = self.size[0] / crop_w scale_h = self.size[1] / crop_h scale = min(scale_w, scale_h) h = int(crop_h * scale) w = int(crop_w * scale)padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype) padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) img = padimg如果進(jìn)行可視化,會顯示如下所示:
可以看出,這種裁剪策略雖然簡單暴力,但是為了拼接成640x640的輸出,會帶來大量無關(guān)全黑像素區(qū)域。
MakeICDARData(數(shù)據(jù)重新組織類)
DB/data/processes/make_icdar_data.py
就是簡單的組織數(shù)據(jù)而已
#Making ICDAE format #返回值: OrderedDict(image=data['image'],polygons=polygons,ignore_tags=ignore_tags,shape=shape,filename=filename,is_training=data['is_training'])MakeSegDetectionData(生成概率圖和對應(yīng)mask類)
DB/data/processes/make_seg_detection_data.py
功能:將多邊形數(shù)據(jù)轉(zhuǎn)化為mask格式即概率圖gt,并且標(biāo)記哪些多邊形是忽略區(qū)域
#Making binary mask from detection data with ICDAR format 輸入:image,polygons,ignore_tags,filename 輸出:gt(shape:[1,h,w]),mask (shape:[h,w])(用于后面計算binary loss)
? 為了防止標(biāo)注間相互粘連,不好后處理,區(qū)分實例,目前做法都是會進(jìn)行shrink即沿著多邊形標(biāo)注的每條邊進(jìn)行向內(nèi)縮減一定像素,得到縮減的gt,然后才進(jìn)行訓(xùn)練;在測試時候再采用相反的手動還原回來。
? 縮減做法采用的也是常規(guī)的Vatti clipping algorithm,是通過pyclipper庫實現(xiàn)的,縮減比例是默認(rèn)0.4,公式是:
r=0.4,A是多邊形面積,L是多邊形周長,通過該公式就可以對每個不同大小的多邊形計算得到一個唯一的D,代表每條邊的向內(nèi)縮放像素個數(shù)。
gt = np.zeros((1, h, w), dtype=np.float32)#shrink后得到概率圖,包括所有區(qū)域mask = np.ones((h, w), dtype=np.float32)#指示哪些區(qū)域是忽略區(qū)域,0就是忽略區(qū)域for i in range(len(polygons)):polygon = polygons[i]height = max(polygon[:, 1]) - min(polygon[:, 1])width = max(polygon[:, 0]) - min(polygon[:, 0])#如果是忽略樣本,或者高寬過小,則mask對應(yīng)位置設(shè)置為0即可if ignore_tags[i] or min(height, width) < self.min_text_size:cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)ignore_tags[i] = Trueelse:#沿著每條邊進(jìn)行shrinkpolygon_shape = Polygon(polygon)#多邊形分析庫#每條邊收縮距離:polygon, D=A(1-r^2)/Ldistance = polygon_shape.area * \(1 - np.power(self.shrink_ratio, 2)) / polygon_shape.lengthsubject = [tuple(l) for l in polygons[i]]#實現(xiàn)坐標(biāo)的偏移padding = pyclipper.PyclipperOffset()padding.AddPath(subject, pyclipper.JT_ROUND,pyclipper.ET_CLOSEDPOLYGON)shrinked = padding.Execute(-distance)#得到縮放后的多邊形if shrinked == []:cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)ignore_tags[i] = Truecontinueshrinked = np.array(shrinked[0]).reshape(-1, 2)cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)如果進(jìn)行可視化,如下所示:
? 概率圖內(nèi)部全白區(qū)域就是概率圖的label,右圖是忽略區(qū)域mask,0為忽略區(qū)域,到時候該區(qū)域是不計算概率圖loss的。
MakeBorderMap(生成閾值圖和對應(yīng)Mask類)
DB/data/make_border_map.py
功能:計算閾值圖和對應(yīng)mask。
輸入:預(yù)處理后的image info: image,polygons,ignore_tags 輸出:thresh_map,thresh_mask (用于后面計算thresh loss)? 仔細(xì)看閾值圖的標(biāo)注,首先紅線點是poly標(biāo)注;然后對該多邊形先進(jìn)行shrink操作,得到藍(lán)線; 然后向外反向shrink同樣的距離,得到綠色;閾值圖就是綠線和藍(lán)色區(qū)域,以紅線為起點,計算在綠線和藍(lán)線區(qū)域內(nèi)的點距離紅線的距離,故為距離圖。
其代碼的處理邏輯是:
流程:
canvas = np.zeros(image.shape[:2], dtype=np.float32) mask = np.zeros(image.shape[:2], dtype=np.float32)draw_border_map(polygons[i], canvas, mask=mask) canvas = canvas * (0.7 - 0.3) + 0.3 data['thresh_map'] = canvas data['thresh_mask'] = maskdraw_border_map
#處理每條polydef draw_border_map(self, polygon, canvas, mask):polygon = np.array(polygon)assert polygon.ndim == 2assert polygon.shape[1] == 2#向外擴(kuò)展polygon_shape = Polygon(polygon)distance = polygon_shape.area * \(1 - np.power(self.shrink_ratio, 2)) / polygon_shape.lengthsubject = [tuple(l) for l in polygon]padding = pyclipper.PyclipperOffset()padding.AddPath(subject, pyclipper.JT_ROUND,pyclipper.ET_CLOSEDPOLYGON)padded_polygon = np.array(padding.Execute(distance)[0])#shape:[12,2]擴(kuò)大和縮減一樣的像素cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)#內(nèi)部全部填充1#計算最小包圍poly矩形xmin = padded_polygon[:, 0].min()xmax = padded_polygon[:, 0].max()ymin = padded_polygon[:, 1].min()ymax = padded_polygon[:, 1].max()width = xmax - xmin + 1height = ymax - ymin + 1#裁剪掉無關(guān)區(qū)域,加快計算速度polygon[:, 0] = polygon[:, 0] - xminpolygon[:, 1] = polygon[:, 1] - ymin#最小包圍矩形的所有位置坐標(biāo)xs = np.broadcast_to(np.linspace(0, width - 1, num=width).reshape(1, width), (height, width))ys = np.broadcast_to(np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width))distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)for i in range(polygon.shape[0]):#對每條邊進(jìn)行遍歷j = (i + 1) % polygon.shape[0]#計算圖片上所有點到線上面的距離absolute_distance = self.distance(xs, ys, polygon[i], polygon[j])#僅僅保留0-1之間的位置,得到距離圖distance_map[i] = np.clip(absolute_distance / distance, 0, 1)distance_map = distance_map.min(axis=0)#繪制到原圖上xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)#如果有多個ploy實例重合,則該區(qū)域取最大值canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(1 - distance_map[ymin_valid-ymin:ymax_valid-ymax+height,xmin_valid-xmin:xmax_valid-xmax+width],canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])可視化如下所示:
采用matpoltlib繪制距離圖會更好看
NormalizeImage
DB/data/processes/normalize_image.py
圖片歸一化類
FilterKeys
DB/data/processes/filter_keys.py
字典數(shù)據(jù)過濾類,具體是把superfluous里面的key和value刪掉,不輸入網(wǎng)絡(luò)中
#刪除無用的圖片信息,只保留信息: dict("image","gt","mask","thresh_map","thresh_mask")模型結(jié)構(gòu)
DB/structure/model.py
模型結(jié)構(gòu)配置部分:
builder: class: Buildermodel: SegDetectorModelmodel_args:backbone: deformable_resnet18decoder: SegDetectordecoder_args: adaptive: Truein_channels: [64, 128, 256, 512]k: 50骨干網(wǎng)絡(luò)和FPN
? 骨架網(wǎng)絡(luò)采用的是resnet18或者resnet50,為了增加網(wǎng)絡(luò)特征提取能力,在layer2、layer3和layer4模塊內(nèi)部引入了變形卷積dcnv2模塊。在resnet輸出的4個特征圖后面采用標(biāo)準(zhǔn)的FPN網(wǎng)絡(luò)結(jié)構(gòu),得到4個增強(qiáng)后輸出,然后cat進(jìn)來,得到1/4的特征圖輸出fuse。
? 其中,resnet骨架特征提取代碼在backbones/resnet.py里,具體是輸出x2, x3, x4, x5,分別是1/4~1/32尺寸。FPN部分代碼在decoders/seg_detector.py里面.
head部分(decoder)
DB/decoders/seg_detector.py
? 輸出head在訓(xùn)練時候包括三個分支,分別是probability map、threshold map和經(jīng)過DB模塊計算得到的approximate binary map。三個圖通道都是1,輸出和輸入是一樣大的。要想分割精度高,高分辨率輸出是必要的。
**輸出:**binary、thresh、thresh_binary
fuse = torch.cat((p5, p4, p3, p2), 1) #推理時,只需返回binary binary = self.binarize(fuse) thresh = self.thresh(fuse) thresh_binary = self.step_function(binary, thresh)binary
? 對fuse特征圖經(jīng)過一系列卷積和反卷積,擴(kuò)大到和原圖一樣大的輸出,然后經(jīng)過sigmod層得到0-1輸出概率圖probability map
self.binarize = nn.Sequential(nn.Conv2d(inner_channels, inner_channels //4, 3, padding=1, bias=bias),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),nn.ConvTranspose2d(inner_channels//4, inner_channels//4, 2, 2),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),nn.ConvTranspose2d(inner_channels//4, 1, 2, 2),nn.Sigmoid())self.binarize.apply(self.weights_init)thresh
? 同時對fuse特征圖采用類似上采樣操作,經(jīng)過sigmod層的0-1輸出閾值圖threshold map
if adaptive:self.thresh = self._init_thresh(inner_channels, serial=serial, smooth=smooth, bias=bias)self.thresh.apply(self.weights_init)def _init_thresh(self, inner_channels,serial=False, smooth=False, bias=False):in_channels = inner_channelsif serial:in_channels += 1self.thresh = nn.Sequential(nn.Conv2d(in_channels, inner_channels //4, 3, padding=1, bias=bias),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),self._init_upsample(inner_channels // 4, inner_channels//4, smooth=smooth, bias=bias),BatchNorm2d(inner_channels//4),nn.ReLU(inplace=True),self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),nn.Sigmoid())return self.threshstep_function
? 將這兩個輸出圖經(jīng)過DB模塊得到approximate binary map
torch.reciprocal(1 + torch.exp(-self.k * (binary - thresh)))損失函數(shù)
DB/decoders/seg_detector_loss.py
? 輸出是單個單通道圖,probability map和approximate binary map是典型的分割輸出,故其loss就是普通的bce,但是為了平衡正負(fù)樣本,還額外采用了難負(fù)樣本采樣策略,對背景區(qū)域和前景區(qū)域采用3:1的設(shè)置。對于threshold map,其輸出不一定是0-1之間,后面會介紹其值的范圍,當(dāng)前采用的是L1 loss,且僅僅計算擴(kuò)展后的多邊形內(nèi)部區(qū)域,其余區(qū)域忽略。
Ls是概率圖,Lt是閾值圖,Lb是近似二值化圖,
? 本文整個論文Loss的實現(xiàn)在decoders/seg_detector_loss.py的L1BalanceCELoss類,可以發(fā)現(xiàn)其實approximate binary map采用的并不是論文中的bce,而是可以克服正負(fù)樣本平衡的dice loss。一般在高度不平衡的二值分割任務(wù)中,dice loss效果會比純bce好,但是更好的策略是dice loss +bce loss。
loss = dice_loss + 10 * l1_loss + 5*bce_lossbinary loss
bce_loss = self.bce_loss(pred['binary'], batch['gt'], batch['mask'])bce_loss:
DB/decoders/balance_cross_entropy_loss.py
def forward(self,pred: torch.Tensor,gt: torch.Tensor,mask: torch.Tensor,return_origin=False):'''Args:pred: shape :math:`(N, 1, H, W)`, the prediction of networkgt: shape :math:`(N, 1, H, W)`, the targetmask: shape :math:`(N, H, W)`, the mask indicates positive regions'''positive = (gt * mask).byte()negative = ((1 - gt) * mask).byte()positive_count = int(positive.float().sum())#負(fù)樣本個數(shù)為positive_count的self.negative_ratio倍數(shù)negative_count = min(int(negative.float().sum()),int(positive_count * self.negative_ratio))loss = nn.functional.binary_cross_entropy(pred, gt, reduction='none')[:, 0, :, :]positive_loss = loss * positive.float()negative_loss = loss * negative.float()#按照loss選擇topK個negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)balance_loss = (positive_loss.sum() + negative_loss.sum()) /\(positive_count + negative_count + self.eps)if return_origin:return balance_loss, lossreturn balance_lossthresh loss
l1_loss, l1_metric = self.l1_loss(pred['thresh'], batch['thresh_map'], batch['thresh_mask'])l1_loss:
DB/decoders/l1_loss.py
class MaskL1Loss(nn.Module):def __init__(self):super(MaskL1Loss, self).__init__()def forward(self, pred: torch.Tensor, gt, mask):mask_sum = mask.sum()if mask_sum.item() == 0:return mask_sum, dict(l1_loss=mask_sum)else:loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sumreturn loss, dict(l1_loss=loss)thresh_binary loss
dice_loss = self.dice_loss(pred['thresh_binary'], batch['gt'], batch['mask'])dice_loss:
DB/decoders/dice_loss.py
class DiceLoss(nn.Module):'''Loss function from https://arxiv.org/abs/1707.03237,where iou computation is introduced heatmap manner to measure thediversity bwtween tow heatmaps.'''def __init__(self, eps=1e-6):super(DiceLoss, self).__init__()self.eps = epsdef forward(self, pred: torch.Tensor, gt, mask, weights=None):'''pred: one or two heatmaps of shape (N, 1, H, W),the losses of tow heatmaps are added together.gt: (N, 1, H, W)mask: (N, H, W)'''assert pred.dim() == 4, pred.dim()return self._compute(pred, gt, mask, weights)def _compute(self, pred, gt, mask, weights):if pred.dim() == 4:pred = pred[:, 0, :, :]gt = gt[:, 0, :, :]assert pred.shape == gt.shapeassert pred.shape == mask.shapeif weights is not None:assert weights.shape == mask.shapemask = weights * maskintersection = (pred * gt * mask).sum()union = (pred * mask).sum() + (gt * mask).sum() + self.epsloss = 1 - 2.0 * intersection / unionassert loss <= 1return lossbinary與thresh_binary的標(biāo)簽都是用的gt
thresh的標(biāo)簽用的thresh_map
邏輯推理
配置如下:
- name: validate_dataclass: ImageDatasetdata_dir:- '/remote_workspace/ocr/public_dataset/icdar2015/'data_list:- '/remote_workspace/ocr/public_dataset/icdar2015/test_list.txt'processes:- class: AugmentDetectionDataaugmenter_args:- ['Resize', {'width': 1280, 'height': 736}]# - ['Resize', {'width': 2048, 'height': 1152}]only_resize: Truekeep_ratio: False- class: MakeICDARData- class: MakeSegDetectionData- class: NormalizeImage? 如果不考慮label,則其處理邏輯和訓(xùn)練邏輯有一點不一樣,其把圖片統(tǒng)一resize到指定的長度進(jìn)行預(yù)測。
前面說過閾值圖分支其實可以相當(dāng)于輔助分支,可以聯(lián)合優(yōu)化各個分支性能。故在測試時候發(fā)現(xiàn)概率圖預(yù)測值已經(jīng)蠻好了,故在測試階段實際上把閾值圖分支移除了,只需要概率圖輸出即可。
后處理邏輯在structure/representers/seg_detector_representer.py,本文特色就是后處理比較簡單,故流程為:
采用作者提供的訓(xùn)練好的權(quán)重進(jìn)行預(yù)測,可視化預(yù)測結(jié)果如下所示:
論文中指標(biāo)結(jié)果:
可以看出變形卷積和閾值圖對整個性能都有比較大的促進(jìn)作用。
測試icdar2015數(shù)據(jù)結(jié)果:
補(bǔ)充
語義分割中的loss function
cross entropy loss
用于圖像語義分割任務(wù)的最常用損失函數(shù)是像素級別的交叉熵?fù)p失,這種損失會逐個檢查每個像素,將對每個像素類別的預(yù)測結(jié)果(概率分布向量)與我們的獨(dú)熱編碼標(biāo)簽向量進(jìn)行比較。
假設(shè)我們需要對每個像素的預(yù)測類別有5個,則預(yù)測的概率分布向量長度為5:
每個像素對應(yīng)的損失函數(shù)為:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-OeRSbkGm-1610966579269)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bpixel+loss%7D+%3D±%5Csum_%7Bclasses%7D+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29+%5C%5C)]
整個圖像的損失就是對每個像素的損失求平均值。
特別注意的是,binary entropy loss 是針對類別只有兩個的情況,簡稱 bce loss,損失函數(shù)公式為:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-DOMD9pF8-1610966579270)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bbce+loss%7D+%3D±+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29%5C%5C)]
weighted loss
由于交叉熵?fù)p失會分別評估每個像素的類別預(yù)測,然后對所有像素的損失進(jìn)行平均,因此我們實質(zhì)上是在對圖像中的每個像素進(jìn)行平等地學(xué)習(xí)。如果多個類在圖像中的分布不均衡,那么這可能導(dǎo)致訓(xùn)練過程由像素數(shù)量多的類所主導(dǎo),即模型會主要學(xué)習(xí)數(shù)量多的類別樣本的特征,并且學(xué)習(xí)出來的模型會更偏向?qū)⑾袼仡A(yù)測為該類別。
FCN論文和U-Net論文中針對這個問題,對輸出概率分布向量中的每個值進(jìn)行加權(quán),即希望模型更加關(guān)注數(shù)量較少的樣本,以緩解圖像中存在的類別不均衡問題。
比如對于二分類,正負(fù)樣本比例為1: 99,此時模型將所有樣本都預(yù)測為負(fù)樣本,那么準(zhǔn)確率仍有99%這么高,但其實該模型沒有任何使用價值。
為了平衡這個差距,就對正樣本和負(fù)樣本的損失賦予不同的權(quán)重,帶權(quán)重的二分類損失函數(shù)公式如下:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-PtRPVOrh-1610966579271)(https://www.zhihu.com/equation?tex=%5Ctext+%7Bpos_weight%7D+%3D+%5Cfrac%7B%5Ctext+%7Bnum_neg%7D%7D%7B%5Ctext+%7Bnum_pos%7D%7D+%5C%5C+%5Ctext+%7Bloss%7D+%3D±+%5Ctext+%7Bpos_weight%7D+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29%5C%5C)]
要減少假陰性樣本的數(shù)量,可以增大 pos_weight;要減少假陽性樣本的數(shù)量,可以減小 pos_weight。
focal loss
上面針對不同類別的像素數(shù)量不均衡提出了改進(jìn)方法,但有時還需要將像素分為難學(xué)習(xí)和容易學(xué)習(xí)這兩種樣本。
容易學(xué)習(xí)的樣本模型可以很輕松地將其預(yù)測正確,模型只要將大量容易學(xué)習(xí)的樣本分類正確,loss就可以減小很多,從而導(dǎo)致模型不怎么顧及難學(xué)習(xí)的樣本,所以我們要想辦法讓模型更加關(guān)注難學(xué)習(xí)的樣本。
對于較難學(xué)習(xí)的樣本,將 bce loss 修改為:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-6eZmpbcM-1610966579272)(https://www.zhihu.com/equation?tex=-+%281-y_%7Bpred%7D%29%5E%5Cgamma+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+y_%7Bpred%7D%5E%5Cgamma+%5Ctimes+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29+%5C%5C)]
其中的 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-pOuXJT0r-1610966579273)(https://www.zhihu.com/equation?tex=%5Cgamma)] 通常設(shè)置為2。
舉個例子,預(yù)測一個正樣本,如果預(yù)測結(jié)果為0.95,這是一個容易學(xué)習(xí)的樣本,有 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-w4wQuNnV-1610966579274)(https://www.zhihu.com/equation?tex=%281-0.95%29%5E2%3D0.0025)] ,損失直接減少為原來的1/400。
而如果預(yù)測結(jié)果為0.4,這是一個難學(xué)習(xí)的樣本,有 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-QQu9pL2M-1610966579276)(https://www.zhihu.com/equation?tex=%281-0.5%29%5E2%3D0.25)] ,損失減小為原來的1/4,雖然也在減小,但是相對來說,減小的程度小得多。
所以通過這種修改,就可以使模型更加專注于學(xué)習(xí)難學(xué)習(xí)的樣本。
而將這個修改和對正負(fù)樣本不均衡的修改合并在一起,就是大名鼎鼎的 focal loss:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-23Dr9Qbl-1610966579278)(https://www.zhihu.com/equation?tex=-+%5Calpha+%281-y_%7Bpred%7D%29%5E%5Cgamma+%5Ctimes+y_%7Btrue%7D+log+%28y_%7Bpred%7D%29±+%281-%5Calpha%29+y_%7Bpred%7D%5E%5Cgamma+%5Ctimes+%281-y_%7Btrue%7D%29+log+%281-y_%7Bpred%7D%29+%5C%5C)]
dice soft loss
Dice系數(shù)計算
語義分割任務(wù)中常用的還有一個基于 Dice 系數(shù)的損失函數(shù),該系數(shù)實質(zhì)上是兩個樣本之間重疊的度量。此度量范圍為 0~1,其中 Dice 系數(shù)為1表示完全重疊。Dice 系數(shù)最初是用于二進(jìn)制數(shù)據(jù)的,可以計算為:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-SGCRjBn0-1610966579279)(https://www.zhihu.com/equation?tex=Dice+%3D+%5Cfrac+%7B2+%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C%7D+%5C%5C)]
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-TDJx3q9u-1610966579281)(https://www.zhihu.com/equation?tex=%7CA+%5Ccap+B%7C)] 代表集合A和B之間的公共元素,并且 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-Y7BLfLdh-1610966579284)(https://www.zhihu.com/equation?tex=%7C+A+%7C)] 代表集合A中的元素數(shù)量(對于集合B同理)。
對于在預(yù)測的分割掩碼上評估 Dice 系數(shù),我們可以將 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-wvhJvFjz-1610966579286)(https://www.zhihu.com/equation?tex=%7CA+%5Ccap+B%7C)] 近似為預(yù)測掩碼和標(biāo)簽掩碼之間的逐元素乘法,然后對結(jié)果矩陣求和。
計算 Dice 系數(shù)的分子中有一個2,那是因為分母中對兩個集合的元素個數(shù)求和,兩個集合的共同元素被加了兩次。
Dice loss
為了設(shè)計一個可以最小化的損失函數(shù),可以簡單地使用 [外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-zTwxiWv4-1610966579288)(https://www.zhihu.com/equation?tex=1-Dice+)]。 這種損失函數(shù)被稱為 soft Dice loss,這是因為我們直接使用預(yù)測出的概率,而不是使用閾值將其轉(zhuǎn)換成一個二進(jìn)制掩碼。
Dice loss是針對前景比例太小的問題提出的,dice系數(shù)源于二分類,本質(zhì)上是衡量兩個樣本的重疊部分。
對于二分類問題,一般預(yù)測值分為以下幾種:
- TP: true positive,真陽性,預(yù)測是陽性,預(yù)測對了,實際也是正例。
- TN: true negative,真陰性,預(yù)測是陰性,預(yù)測對了,實際也是負(fù)例。
- FP: false positive,假陽性,預(yù)測是陽性,預(yù)測錯了,實際是負(fù)例。
- FN: false negative,假陰性,預(yù)測是陰性,預(yù)測錯了,實際是正例。
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-I6TNYlF0-1610966579290)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JEcEJoM2ZGMkd2cjhxbHM2eG04Z2JBUURyUHIyT1VIN2ljWGVSWGdDckVjUVJteDBMTXI4bURBLzY0MA.png)]
這里dice coefficient可以寫成如下形式:
dice=2TP2TP+FP+FNdice=\frac{2TP}{2TP+FP+FN} dice=2TP+FP+FN2TP?
而我們知道:
可見dice coefficient是等同**「F1 score」,直觀上dice coefficient是計算 與 的相似性,本質(zhì)上則同時隱含precision和recall兩個指標(biāo)。可見dice loss是直接優(yōu)化「F1 score」**。
對于神經(jīng)網(wǎng)絡(luò)的輸出,分子與我們的預(yù)測和標(biāo)簽之間的共同激活有關(guān),而分母分別與每個掩碼中的激活數(shù)量有關(guān),這具有根據(jù)標(biāo)簽掩碼的尺寸對損失進(jìn)行歸一化的效果。
對于每個類別的mask,都計算一個 Dice 損失:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-v2aiaP5D-1610966579294)(https://www.zhihu.com/equation?tex=1-+%5Cfrac+%7B2+%5Csum%5Climits_%7Bpixels%7D+y_%7Btrue%7D+y_%7Bpred%7D%7D%7B%5Csum%5Climits_%7Bpixels%7D+%28y_%7Btrue%7D%5E2+%2B+y_%7Bpred%7D%5E2%29%7D+%5C%5C)]
將每個類的 Dice 損失求和取平均,得到最后的 Dice soft loss。
梯度分析
從dice loss的定義可以看出,dice loss 是一種**「區(qū)域相關(guān)」**的loss。意味著某像素點的loss以及梯度值不僅和該點的label以及預(yù)測值相關(guān),和其他點的label以及預(yù)測值也相關(guān),這點和ce (交叉熵cross entropy) ?loss 不同。
dice loss 是應(yīng)用于語義分割而不是分類任務(wù),并且是一個區(qū)域相關(guān)的loss,因此更適合針對多點的情況進(jìn)行分析。由于多點輸出的情況比較難用曲線呈現(xiàn),這里使用模擬預(yù)測值的形式觀察梯度的變化。
下圖為原始圖片和對應(yīng)的label:[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-xPLvYa7F-1610966579296)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JGMWR3blNGU1R5VEY4VFllNHN3SHBrR1FOM3JrWnRQamtYZGhoWjBydWo3RFFyamlibmowZ3lBLzY0MA.png)]
為了便于梯度可視化,這里對梯度求絕對值操作,因為我們關(guān)注的是梯度的大小而非方向。另外梯度值都乘以 保證在容易辨認(rèn)的范圍。
首先定義如下熱圖,值越大,顏色越亮,反之亦然:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-BeCmQcqD-1610966579298)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0I1YXBtUDZNWGJaNDhocklkWmE3dHpGdEZKQmJwSFV6Q0tqTUhWRW5mQ3MyTmh1b2o4TTJTNVEvNjQw.png)]
預(yù)測值變化( 值,圖上的數(shù)字為預(yù)測值區(qū)間):
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-jY3HOS0H-1610966579299)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JiT3FrQzRMYXZJMThMbVdxQVNXTmE3STdjR2EwMm95cnB6cVhuZTRMNWhwajJDOWRySXUyS2cvNjQw.png)]
dice loss 對應(yīng) 值的梯度:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-pXdKApFF-1610966579301)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0JrMXV2Sjhyem1qanMyMXdraWJzYkRBbktiSlVqTXFjaWFYSUt1VkJSaWFDd213TGZpYTMyanFUaWFuQS82NDA.png)]
ce loss 對應(yīng) 值的梯度:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-i9VOLACe-1610966579302)(C:\F\notebook\DB\aHR0cHM6Ly9tbWJpei5xcGljLmNuL21tYml6X3BuZy9pYVRhOHV0NkhpYXdBZWpDcGhDVGtpY3EyVlRaaWJJTTBDR0I1aFdVUWliVk1CaWFtRzFxbXdocnp6ZUVqWTY1dmFhZWtlV05iMGVJcGJBUkNpYkoyMFdHUmliZmJRLzY0MA.png)]
可以看出:
- 一般情況下,dice loss 正樣本的梯度大于背景樣本的,尤其是剛開始網(wǎng)絡(luò)預(yù)測接進(jìn)0.5的時候。說明dice loss 更具有指向性,更加偏向于正樣本,保證有較低的FN。
- 負(fù)樣本(背景區(qū)域)也會產(chǎn)生梯度
- 極端情況下,網(wǎng)絡(luò)預(yù)測接進(jìn)0或1時,對應(yīng)點梯度值極小,dice loss 存在梯度飽和現(xiàn)象。此時預(yù)測失敗(FN,FP)的情況很難扭轉(zhuǎn)回來。不過該情況出現(xiàn)的概率較低,因為網(wǎng)絡(luò)初始化輸出接近0.5,此時具有較大的梯度值。而網(wǎng)絡(luò)通過梯度下降的方式更新參數(shù),只會逐漸削弱預(yù)測失敗的像素點。
- 對于ce loss,當(dāng)前的點的梯度僅和當(dāng)前預(yù)測值與label的距離相關(guān),預(yù)測越接近label,梯度越小。當(dāng)網(wǎng)絡(luò)預(yù)測接近0或1時,梯度依然保持該特性。
- 對比發(fā)現(xiàn),訓(xùn)練前中期,dice loss 下正樣本的梯度值相對于ce loss ,顏色更亮,值更大。說明dice loss對挖掘正樣本更加有優(yōu)勢。
【dice loss為何能夠解決正負(fù)樣本不平衡問題?】
因為dice loss 是一個區(qū)域相關(guān)的loss。區(qū)域相關(guān)的意思就是,當(dāng)前像素的loss不光和當(dāng)前像素的預(yù)測值相關(guān),和其他點的值也相關(guān)。dice loss的求交的形式可以理解為mask掩碼操作,因此不管圖片有多大,固定大小的正樣本的區(qū)域計算的loss是一樣的,對網(wǎng)絡(luò)起到的監(jiān)督貢獻(xiàn)不會隨著圖片的大小而變化。從上圖可視化也發(fā)現(xiàn),訓(xùn)練更傾向于挖掘前景區(qū)域,正負(fù)樣本不平衡的情況就是前景占比較小。而ce loss 會公平處理正負(fù)樣本,當(dāng)出現(xiàn)正樣本占比較小時,就會被更多的負(fù)樣本淹沒。
【dice loss背景區(qū)域能否起到監(jiān)督作用?】
可以的,但是會小于前景區(qū)域。和直觀理解不同的是,隨著訓(xùn)練的進(jìn)行,背景區(qū)域也能產(chǎn)生較為可觀的梯度。這點和單點的情況分析不同。這里求偏導(dǎo),當(dāng)t_i=0 時:
可以看出, 背景區(qū)域的梯度是存在的,只有預(yù)測值命中的區(qū)域極小時, 背景梯度才會很小.
【dice loss 為何訓(xùn)練會很不穩(wěn)定?】
在使用dice loss時,一般正樣本為小目標(biāo)時會產(chǎn)生嚴(yán)重的震蕩。因為在只有前景和背景的情況下,小目標(biāo)一旦有部分像素預(yù)測錯誤,那么就會導(dǎo)致loss值大幅度的變動,從而導(dǎo)致梯度變化劇烈。可以假設(shè)極端情況,只有一個像素為正樣本,如果該像素預(yù)測正確了,不管其他像素預(yù)測如何,loss 就接近0,預(yù)測錯誤了,loss 接近1。而對于ce loss,loss的值是總體求平均的,更多會依賴負(fù)樣本的地方。
總結(jié)
dice loss 對正負(fù)樣本嚴(yán)重不平衡的場景有著不錯的性能,訓(xùn)練過程中更側(cè)重對前景區(qū)域的挖掘。但訓(xùn)練loss容易不穩(wěn)定,尤其是小目標(biāo)的情況下。另外極端情況會導(dǎo)致梯度飽和現(xiàn)象。因此有一些改進(jìn)操作,主要是結(jié)合ce loss等改進(jìn),比如: ?dice+ce loss,dice + focal loss等,
soft IOU loss
前面我們知道計算 Dice 系數(shù)的公式,其實也可以表示為:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-2X1iB7aM-1610966579304)(https://www.zhihu.com/equation?tex=Dice+%3D+%5Cfrac+%7B2+%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C%7D+%3D+%5Cfrac+%7B2+TP%7D%7B2+TP+%2B+FP+%2B+FN%7D+%5C%5C)]
其中 TP 為真陽性樣本,FP 為假陽性樣本,FN 為假陰性樣本。分子和分母中的 TP 樣本都加了兩次。
IoU 的計算公式和這個很像,區(qū)別就是 TP 只計算一次:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-sN52A4f0-1610966579306)(https://www.zhihu.com/equation?tex=IoU+%3D+%5Cfrac+%7B%7CA+%5Ccap+B%7C%7D%7B%7CA%7C+%2B+%7CB%7C±+%7CA+%5Ccap+B%7C%7D+%3D+%5Cfrac+%7BTP%7D%7BTP+%2B+FP+%2B+FN%7D+%5C%5C)]
和 Dice soft loss 一樣,通過 IoU 計算損失也是使用預(yù)測的概率值:
[外鏈圖片轉(zhuǎn)存失敗,源站可能有防盜鏈機(jī)制,建議將圖片保存下來直接上傳(img-efgoYMzv-1610966579307)(https://www.zhihu.com/equation?tex=loss+%3D±+%5Cfrac+%7B1%7D%7B%7CC%7C%7D+%5Csum%5Climits_c+%5Cfrac+%7B%5Csum%5Climits_%7Bpixels%7D+y_%7Btrue%7D+y_%7Bpred%7D%7D%7B%5Csum%5Climits_%7Bpixels%7D+%28y_%7Btrue%7D+%2B+y_%7Bpred%7D±+y_%7Btrue%7D+y_%7Bpred%7D%29%7D+%5C%5C)]
其中 C 表示總的類別數(shù)。
總結(jié)
交叉熵?fù)p失把每個像素都當(dāng)作一個獨(dú)立樣本進(jìn)行預(yù)測,而 dice loss 和 iou loss 則以一種更“整體”的方式來看待最終的預(yù)測輸出。
預(yù)測值相關(guān),和其他點的值也相關(guān)。dice loss的求交的形式可以理解為mask掩碼操作,因此不管圖片有多大,固定大小的正樣本的區(qū)域計算的loss是一樣的,對網(wǎng)絡(luò)起到的監(jiān)督貢獻(xiàn)不會隨著圖片的大小而變化。從上圖可視化也發(fā)現(xiàn),訓(xùn)練更傾向于挖掘前景區(qū)域,正負(fù)樣本不平衡的情況就是前景占比較小。而ce loss 會公平處理正負(fù)樣本,當(dāng)出現(xiàn)正樣本占比較小時,就會被更多的負(fù)樣本淹沒。
【dice loss背景區(qū)域能否起到監(jiān)督作用?】
可以的,但是會小于前景區(qū)域。和直觀理解不同的是,隨著訓(xùn)練的進(jìn)行,背景區(qū)域也能產(chǎn)生較為可觀的梯度。這點和單點的情況分析不同。這里求偏導(dǎo),當(dāng)t_i=0 時:
可以看出, 背景區(qū)域的梯度是存在的,只有預(yù)測值命中的區(qū)域極小時, 背景梯度才會很小.
【dice loss 為何訓(xùn)練會很不穩(wěn)定?】
在使用dice loss時,一般正樣本為小目標(biāo)時會產(chǎn)生嚴(yán)重的震蕩。因為在只有前景和背景的情況下,小目標(biāo)一旦有部分像素預(yù)測錯誤,那么就會導(dǎo)致loss值大幅度的變動,從而導(dǎo)致梯度變化劇烈。可以假設(shè)極端情況,只有一個像素為正樣本,如果該像素預(yù)測正確了,不管其他像素預(yù)測如何,loss 就接近0,預(yù)測錯誤了,loss 接近1。而對于ce loss,loss的值是總體求平均的,更多會依賴負(fù)樣本的地方。
總結(jié)
dice loss 對正負(fù)樣本嚴(yán)重不平衡的場景有著不錯的性能,訓(xùn)練過程中更側(cè)重對前景區(qū)域的挖掘。但訓(xùn)練loss容易不穩(wěn)定,尤其是小目標(biāo)的情況下。另外極端情況會導(dǎo)致梯度飽和現(xiàn)象。因此有一些改進(jìn)操作,主要是結(jié)合ce loss等改進(jìn),比如: ?dice+ce loss,dice + focal loss等,
soft IOU loss
前面我們知道計算 Dice 系數(shù)的公式,其實也可以表示為:
[外鏈圖片轉(zhuǎn)存中…(img-2X1iB7aM-1610966579304)]
其中 TP 為真陽性樣本,FP 為假陽性樣本,FN 為假陰性樣本。分子和分母中的 TP 樣本都加了兩次。
IoU 的計算公式和這個很像,區(qū)別就是 TP 只計算一次:
[外鏈圖片轉(zhuǎn)存中…(img-sN52A4f0-1610966579306)]
和 Dice soft loss 一樣,通過 IoU 計算損失也是使用預(yù)測的概率值:
[外鏈圖片轉(zhuǎn)存中…(img-efgoYMzv-1610966579307)]
其中 C 表示總的類別數(shù)。
總結(jié)
交叉熵?fù)p失把每個像素都當(dāng)作一個獨(dú)立樣本進(jìn)行預(yù)測,而 dice loss 和 iou loss 則以一種更“整體”的方式來看待最終的預(yù)測輸出。
這兩類損失是針對不同情況,各有優(yōu)點和缺點,在實際應(yīng)用中,可以同時使用這兩類損失來進(jìn)行互補(bǔ)。
總結(jié)
- 上一篇: 基于matlab的杨氏双缝干涉模拟仿真+
- 下一篇: Logit Adjust