RetinaNet + Focal Loss
A major reason one-stage detectors fall short in accuracy is the imbalance between positive and negative samples. Taking YOLO as an example, each grid cell makes 5 predictions; the gap between positive and negative counts is already large, and this 5-fold multiplication amplifies the disparity further.
The paper proposes a new classification loss, focal loss, which suppresses the weight of easily classified samples and focuses training on hard-to-distinguish ones, effectively controlling the positive/negative ratio and preventing imbalance. In short, focal loss addresses both the positive/negative imbalance and the easy/hard imbalance.
The loss is

FL(p_t) = -α_t · (1 - p_t)^γ · log(p_t), where p_t = p for positives and 1 - p for negatives, and α_t = α for positives and 1 - α for negatives.

Here α controls the positive/negative sample imbalance and γ controls the easy/hard sample imbalance; typical values are α = 0.25 and γ = 2. With α = 0.25, a positive sample is weighted by 0.25 and a negative by 0.75, a 1:3 ratio at the α level; the modulating factor (1 - p_t)^γ then sharply reduces the loss of easy samples relative to hard ones, and since the overwhelming majority of easy samples are negatives, the combined effect rebalances both axes.
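As a hedged illustration (the probability values here are chosen for demonstration, not taken from the paper), the snippet below shows how the focal weight α_t · (1 - p_t)^γ scales the cross-entropy of an easy versus a hard negative with the paper's α = 0.25, γ = 2:

```python
# Illustration: focal weight alpha_t * (1 - p_t)^gamma for easy vs. hard examples.
# The example probabilities (0.01, 0.90) are arbitrary, chosen for demonstration.
alpha, gamma = 0.25, 2.0

def focal_weight(p, y):
    """p: predicted foreground probability; y: 1 for positive, 0 for negative."""
    p_t = p if y == 1 else 1.0 - p          # probability of the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return alpha_t * (1.0 - p_t) ** gamma

easy_neg = focal_weight(0.01, 0)  # confident background: 0.75 * 0.01^2 = 7.5e-05
hard_neg = focal_weight(0.90, 0)  # misclassified background: 0.75 * 0.9^2 = 0.6075
print(easy_neg, hard_neg, round(hard_neg / easy_neg))  # ratio ~8100
```

An easy negative's loss is weighted roughly 8100 times less than a hard negative's, which is how the modulating factor keeps the flood of easy background anchors from dominating training.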
The model uses an FPN backbone with levels P3 through P7; P7 improves detection of large objects.
On FPN levels P3-P7, anchors are set with base areas from 32×32 up to 512×512 (doubling per level), with aspect ratios {1:2, 1:1, 2:1} and three scales {2^0, 2^(1/3), 2^(2/3)}, giving 9 anchors per level; across levels the anchors cover sizes from 32 to 813 pixels. Each anchor is assigned a K-dimensional one-hot classification target (K is the number of classes) and a 4-dimensional box-regression target.
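A short sketch of how the 9 anchors per level and the 32-813 size range arise from the scales and ratios above (base sizes per level as described in the paper):

```python
# Base anchor size per FPN level, doubling from P3 to P7.
base_sizes = {'P3': 32, 'P4': 64, 'P5': 128, 'P6': 256, 'P7': 512}
scales = [2 ** 0, 2 ** (1 / 3), 2 ** (2 / 3)]   # 3 scales per level
ratios = [0.5, 1.0, 2.0]                        # aspect ratios 1:2, 1:1, 2:1

for level, base in base_sizes.items():
    # 3 scales x 3 ratios = 9 anchors at every spatial position of this level
    sizes = [round(base * s, 1) for s in scales]
    print(level, len(scales) * len(ratios), 'anchors; base sizes', sizes)

# Smallest anchor: 32 * 2^0 = 32; largest: 512 * 2^(2/3) ~ 812.7, i.e. ~813
print(round(min(base_sizes.values()) * min(scales)),
      round(max(base_sizes.values()) * max(scales)))
```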
The classification subnet predicts, for each of the A anchors at every position, a probability for each of the K classes. As shown in the paper's architecture figure, for each FPN level's output the classification subnet is an FCN: four 3×3×256 convolutional layers, then a final 3×3 convolution with K·A output channels, so each anchor gets a K-dimensional vector of per-class probabilities; because of the one-hot property, the class with the highest score is set to 1 and the remaining K−1 entries to 0. A traditional RPN classification head uses a single 1×1×18 convolution, whereas RetinaNet uses a deeper head of 5 convolutional layers in total; experiments show this deeper head helps the results.

In parallel with the classification subnet, each FPN output is also fed to a box-regression subnet, which is likewise an FCN and predicts the offsets between each anchor and its matched ground-truth box: again four 256-channel convolutions, then a final layer with 4·A output channels, i.e. one (x, y, w, h) vector per anchor. Note that this box regression is class-agnostic. The classification and regression subnets share the same structure but do not share parameters.
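The two heads described above can be sketched as follows. This is a minimal sketch, not the official implementation; the layer counts follow the paper, while K = 80 and the 32×32 feature size are example values chosen here:

```python
import torch
import torch.nn as nn

# Sketch of the RetinaNet heads: four 3x3/256 convs with ReLU,
# then K*A outputs for classification and 4*A for box regression.
# The two heads are structurally identical but do NOT share parameters.
def make_head(out_channels, in_channels=256):
    layers = []
    for _ in range(4):
        layers += [nn.Conv2d(in_channels, 256, 3, padding=1),
                   nn.ReLU(inplace=True)]
    layers.append(nn.Conv2d(256, out_channels, 3, padding=1))
    return nn.Sequential(*layers)

K, A = 80, 9                       # e.g. COCO classes, 9 anchors per position
cls_head = make_head(K * A)        # applied to every FPN level P3-P7
box_head = make_head(4 * A)        # class-agnostic (x, y, w, h) per anchor

feat = torch.rand(1, 256, 32, 32)  # one FPN level's feature map
print(cls_head(feat).shape)        # torch.Size([1, 720, 32, 32])
print(box_head(feat).shape)        # torch.Size([1, 36, 32, 32])
```

The same pair of heads is run over every pyramid level; only the spatial size of the output changes.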
Code:
Three equivalent implementations that compute the loss over both positive and negative samples:
```python
import torch
import torch.nn.functional as F


def focal_loss_one(alpha, beta, cls_preds, gts):
    print('====== implementation 1 ======')
    num_pos = gts.sum()
    # alpha_t: alpha for positives, 1 - alpha for negatives
    alpha_tensor = torch.ones_like(cls_preds) * alpha
    alpha_tensor = torch.where(torch.eq(gts, 1.), alpha_tensor, 1. - alpha_tensor)
    # p_t: p for positives, 1 - p for negatives
    preds = torch.where(torch.eq(gts, 1.), cls_preds, 1. - cls_preds)
    focal_weight = alpha_tensor * torch.pow((1. - preds), beta)
    batch_bce_loss = -(gts * torch.log(cls_preds)
                       + (1. - gts) * torch.log(1. - cls_preds))
    batch_focal_loss = (focal_weight * batch_bce_loss).sum()
    # normalize by the number of positive anchors
    if num_pos != 0:
        mean_batch_focal_loss = batch_focal_loss / num_pos
    else:
        mean_batch_focal_loss = batch_focal_loss
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)


def focal_loss_two(alpha, beta, cls_preds, gts):
    print('====== implementation 2 ======')
    pos_inds = (gts == 1.0).float()
    neg_inds = (gts != 1.0).float()
    # split the focal loss into its positive and negative terms
    pos_loss = -pos_inds * alpha * (1.0 - cls_preds) ** beta * torch.log(cls_preds)
    neg_loss = -neg_inds * (1 - alpha) * (cls_preds ** beta) * torch.log(1.0 - cls_preds)
    num_pos = pos_inds.sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
    if num_pos == 0:
        mean_batch_focal_loss = neg_loss
    else:
        mean_batch_focal_loss = (pos_loss + neg_loss) / num_pos
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)


def focal_loss_three(alpha, beta, cls_preds, gts):
    print('====== implementation 3 ======')
    num_pos = gts.sum()
    pred_sigmoid = cls_preds
    target = gts.type_as(pred_sigmoid)
    # pt here is (1 - p_t), so pt.pow(beta) is the modulating factor
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) * (1 - target)) * pt.pow(beta)
    batch_focal_loss = F.binary_cross_entropy(
        pred_sigmoid, target, reduction='none') * focal_weight
    batch_focal_loss = batch_focal_loss.sum()
    if num_pos != 0:
        mean_batch_focal_loss = batch_focal_loss / num_pos
    else:
        mean_batch_focal_loss = batch_focal_loss
    print('==mean_batch_focal_loss:', mean_batch_focal_loss)


bs = 2
num_class = 3
alpha = 0.25
beta = 2
# (B, num_class) predicted probabilities
cls_preds = torch.rand([bs, num_class], dtype=torch.float)
gts = torch.tensor([0, 2])
# (B, num_class) one-hot targets
gts = F.one_hot(gts, num_classes=num_class).type_as(cls_preds)
focal_loss_one(alpha, beta, cls_preds, gts)
focal_loss_two(alpha, beta, cls_preds, gts)
focal_loss_three(alpha, beta, cls_preds, gts)
```

Computing the loss only on the positive (ground-truth class) entries:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """An implementation of Focal Loss, proposed in
    "Focal Loss for Dense Object Detection":

        Loss(x, class) = -alpha * (1 - softmax(x)[class])^gamma * log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        alpha (1D Tensor): the scalar balancing factor for this criterion.
        gamma (float): gamma > 0; reduces the relative loss for well-classified
            examples (p > .5), putting more focus on hard, misclassified examples.
        size_average (bool): By default, the losses are averaged over observations
            for each minibatch. If set to False, the losses are summed instead.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        else:
            self.alpha = alpha
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=-1)
        # build a one-hot mask from the class indices
        class_mask = inputs.new_zeros(N, C)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.view(-1)]
        # probability of the ground-truth class for each sample
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.log()
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss


def debug_focal():
    # loss is computed only on the ground-truth-class probabilities
    loss = FocalLoss(class_num=8)
    inputs = torch.rand(2, 8)
    targets = torch.tensor([0, 1])
    cost = loss(inputs, targets)
    print('===cost===:', cost)


if __name__ == '__main__':
    debug_focal()
```