【CV】10分钟理解Focal loss数学原理与Pytorch代码
原文鏈接:https://amaarora.github.io/2020/06/29/FocalLoss.html
原文作者:Aman Arora
Focal loss 是一個在目標檢測領域常用的損失函數。最近看到一篇博客,趁這個機會,學習和翻譯一下,與大家一起交流和分享。
在這篇博客中,我們將會理解什么是Focal loss,并且什么時候應該使用它。同時我們會深入理解下其背后的數學原理與pytorch 實現.
什么是Focal loss,它是用來干嘛的?
為什么Focal loss有效,其中的原理是什么?
Alpha and Gamma?
怎么在代碼中實現它?
Credits
什么是Focal loss,它是用來干嘛的?
在了解什么是Focal Loss以及有關它的所有詳細信息之前,我們首先快速直觀地了解Focal Loss的實際作用。Focal loss最早是 He et al 在論文 Focal Loss for Dense Object Detection 中實現的。
在這篇文章發表之前,對象檢測實際上一直被認為是一個很難解決的問題,尤其是很難檢測圖像中的小尺寸對象。請參見下面的示例,與其他圖片相比,摩托車的尺寸相對較小, 所以該模型無法很好地預測摩托車的存在。
fig-1??bce?
在上圖中,模型無法預測摩托車的原因是因為該模型是使用了Binary Cross Entropy loss,這種訓練目標要求模型 對自己的預測真的很有信心。而Focal Loss所做的是,它使模型可以更"放松"地預測事物,而無需80-100%確信此對象是“某物”。簡而言之,它給模型提供了更多的自由,可以在進行預測時承擔一些風險。這在處理高度不平衡的數據集時尤其重要,因為在某些情況下(例如癌癥檢測),即使預測結果為假陽性也可接受,確實需要模型承擔風險并盡量進行預測。
因此,Focal loss在樣本不平衡的情況下特別有用。特別是在“對象檢測”的情況下,大多數像素通常都是背景,圖像中只有很少數的像素具有我們感興趣的對象。
這是經過Focal loss訓練后同一模型對同樣圖片的預測。
fig-2??focal loss prediction
分析這兩者并觀察其中的差異,可能是個很好的主意。這將有助于我們對于Focal loss進行直觀的了解。
那么為什么Focal loss有效,其中的原理是什么?
既然我們已經看到了“Focal loss”可以做什么的一個例子,接下來讓我們嘗試去理解為什么它可以起作用。下面是了解Focal loss的最重要的一張圖:
fig-3 FL vs CE
在上圖中,“藍”線代表交叉熵損失。X軸即“預測為真實標簽的概率”(為簡單起見,將其稱為pt)。舉例來說,假設模型預測某物是自行車的概率為0.6,而它確實是自行車, 在這種情況下的pt為0.6。而如果同樣的情況下對象不是自行車。則pt為0.4,因為此處的真實標簽是0,而對象不是自行車的概率為0.4(1-0.6)。
Y軸是給定pt后Focal loss和CE的loss的值。
從圖像中可以看出,當模型預測為真實標簽的概率為0.6左右時,交叉熵損失仍在0.5左右。因此,為了在訓練過程中減少損失,我們的模型將必須以更高的概率來預測到真實標簽。換句話說,交叉熵損失要求模型對自己的預測非常有信心。但這也同樣會給模型表現帶來負面影響。
深度學習模型會變得過度自信, 因此模型的泛化能力會下降.
這個模型過度自信的問題同樣在另一篇出色的論文 Beyond temperature scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet calibration 被強調過。
另外,作為重新思考計算機視覺的初始架構的一部分而引入的標簽平滑是解決該問題的另一種方法。
Focal loss與上述解決方案不同。從比較Focal loss與CrossEntropy的圖表可以看出,當使用γ> 1的Focal Loss可以減少“分類得好的樣本”或者說“模型預測正確概率大”的樣本的訓練損失,而對于“難以分類的示例”,比如預測概率小于0.5的,則不會減小太多損失。因此,在數據類別不平衡的情況下,會讓模型的注意力放在稀少的類別上,因為這些類別的樣本見過的少,比較難分。
Focal loss的數學定義如下:
Alpha and Gamma?
那么在Focal loss 中的alpha和gamma是什么呢?我們會將alpha記為α,gamma記為γ。
我們可以這樣來理解fig3
γ?控制曲線的形狀.?γ的值越大, 好分類樣本的loss就越小, 我們就可以把模型的注意力投向那些難分類的樣本. 一個大的?γ?讓獲得小loss的樣本范圍擴大了.
同時,當γ=0時,這個表達式就退化成了Cross Entropy Loss,眾所周知地
定義“ pt”如下,按照其真實意義:
將上述兩個式子合并,Cross Entropy Loss其實就變成了下式。
現在我們知道了γ的作用,那么α是干什么的呢?
除了Focal loss以外,另一種處理類別不均衡的方法是引入權重。給稀有類別以高權重,給統治地位的類或普通類以小權重。這些權重我們也可以用α表示。
alpha-CE
加上了這些權重確實幫助處理了類別的 不均衡,focal loss的論文報道:
類間不均衡較大會導致,交叉熵損失在訓練的時候收到影響。易分類的樣本的分類錯誤的損失占了整體損失的絕大部分,并主導梯度。盡管α平衡了正面/負面例子的重要性,但它并未區分簡單/困難例子。
作者想要解釋的是:
盡管我們加上了α, 它也確實對不同的類別加上了不同的權重, 從而平衡了正負樣本的重要性 ,但在大多數例子中,只做這個是不夠的. 我們同樣要做的是減少容易分類的樣本分類錯誤的損失。因為不然的話,這些容易分類的樣本就主導了我們的訓練.
那么Focal loss 怎么處理的呢,它相對交叉熵加上了一個乘性的因子(1 ? pt)**γ,從而像我們上面所講的,降低了易分類樣本區間內產生的loss。
再看下Focal loss的表達,是不是清晰了許多。
怎么在代碼中實現呢?
這是Focal loss在Pytorch中的實現。
class WeightedFocalLoss(nn.Module):"Non weighted version of Focal Loss"def __init__(self, alpha=.25, gamma=2):super(WeightedFocalLoss, self).__init__()self.alpha = torch.tensor([alpha, 1-alpha]).cuda()self.gamma = gammadef forward(self, inputs, targets):BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')targets = targets.type(torch.long)at = self.alpha.gather(0, targets.data.view(-1))pt = torch.exp(-BCE_loss)F_loss = at*(1-pt)**self.gamma * BCE_lossreturn F_loss.mean()如果你理解了alpha和gamma的意思,那么這個實現應該都能理解。同時,像文章中提到的一樣,這里是對BCE進行因子的相乘。
Credits
貼上作者的 twitter ,當然如果大家有什么問題討論,也可以在公眾號留言。
fig-1?and?fig-2?are from the?Fastai 2018 course?Lecture-09!
未完待續
今天給大家分享到這里,感謝大家的閱讀和支持,我們會繼續給大家分享我們的所思所想所學,希望大家都有收獲!
往期精彩回顧適合初學者入門人工智能的路線及資料下載機器學習及深度學習筆記等資料打印機器學習在線手冊深度學習筆記專輯《統計學習方法》的代碼復現專輯 AI基礎下載機器學習的數學基礎專輯獲取一折本站知識星球優惠券,復制鏈接直接打開:https://t.zsxq.com/yFQV7am本站qq群1003271085。加入微信群請掃碼進群:總結
以上是生活随笔為你收集整理的【CV】10分钟理解Focal loss数学原理与Pytorch代码的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【Python】用 Python 来实现
- 下一篇: 【Python】全网最新最全Pyecha