Focal Loss: Theory and Code Implementation
Table of Contents
- Preface
- I. Basic Theory
- II. Implementation
    - 1. Formula
    - 2. Code Implementation
        - 1. Implementation based on binary cross-entropy
        - 2. A Zhihu author's implementation
Preface
This article is based on a post by the blogger 幾時見得清夢.
Original article: https://www.jianshu.com/p/30043bcc90b6
I. Basic Theory
1. Soft gamma: increasing gamma in stages over the course of training may give a further performance boost.
2. alpha is related to the frequency of each class in the training data.
3. F.nll_loss(torch.log(F.softmax(inputs, dim=1)), target) is functionally equivalent to F.cross_entropy(inputs, target).
Internally, F.nll_loss one-hot encodes the target into a tensor with the same shape as its first input and takes the element-wise product with that input (the log-probabilities), i.e. it picks out the log-probability of the true class for each sample. A numerical check is sketched at the end of this section.
[Figure: experimental results with alpha = 1 and different gamma values]
4. What problems does focal loss solve?
(1) Imbalance between classes
(2) Imbalance between easy and hard examples
5. In RetinaNet, besides using focal loss, the initialization is also handled specially. What exactly is done?
In RetinaNet, the bias b of the last conv layer of the classification subnet is initialized to b = -log((1 - π)/π) with π = 0.01, so that at the start of training every anchor is predicted as foreground with probability of roughly π; this keeps the overwhelming number of easy background anchors from dominating the loss in the first iterations.
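As referenced in point 3 above, here is a minimal sketch that checks the nll_loss / cross_entropy equivalence numerically; the tensor sizes and random values are purely illustrative:

```python
import torch
import torch.nn.functional as F

# Dummy logits for 4 samples over 3 classes, plus integer class targets.
inputs = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])

# nll_loss on log(softmax(x)) ...
loss_nll = F.nll_loss(torch.log(F.softmax(inputs, dim=1)), target)
# ... matches cross_entropy applied directly to the raw logits.
loss_ce = F.cross_entropy(inputs, target)

print(loss_nll.item(), loss_ce.item())
assert torch.allclose(loss_nll, loss_ce)
```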
II. Implementation
1. Formula
The standard cross entropy and focal loss are CE(p_t) = -log(p_t) and FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the model's estimated probability for the true class.
For the forward and backward derivation, see this Zhihu post: https://zhuanlan.zhihu.com/p/32631517
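As a quick numerical sketch of the formula above (the p_t values are made up; α = 0.25 and γ = 2 are the defaults reported in the paper):

```python
import torch

def ce(pt):
    # Standard cross entropy on the true-class probability p_t.
    return -torch.log(pt)

def focal(pt, alpha=0.25, gamma=2.0):
    # Focal loss: cross entropy scaled by alpha * (1 - p_t)^gamma.
    return -alpha * (1 - pt) ** gamma * torch.log(pt)

pt = torch.tensor([0.9, 0.5, 0.1])  # easy, medium, hard examples
# The easy example (p_t = 0.9) is scaled by (1 - 0.9)^2 = 0.01, so it contributes
# almost nothing, while the hard example (p_t = 0.1) keeps most of its
# cross-entropy loss (factor 0.81) and therefore dominates the total.
print(ce(pt))
print(focal(pt))
```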
2. Code Implementation
1. Implementation based on binary cross-entropy
```python
# 1. Implementation based on binary cross-entropy
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits    # True if inputs are raw scores rather than probabilities
        self.reduce = reduce    # True: return the mean loss; False: return per-element losses

    def forward(self, inputs, targets):
        # Per-element binary cross entropy (reduction='none' replaces the deprecated reduce=False).
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        # p_t: probability assigned to the true class, recovered from the BCE value.
        pt = torch.exp(-BCE_loss)
        # Focal loss: alpha * (1 - p_t)^gamma * BCE
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
```
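A minimal usage sketch for the class above; the batch size, shapes, and values are made up for illustration:

```python
# Usage sketch (illustrative values only).
criterion = FocalLoss(alpha=1, gamma=2, logits=True, reduce=True)

logits = torch.randn(8, 1, requires_grad=True)   # raw scores from a binary classifier
targets = torch.randint(0, 2, (8, 1)).float()    # 0/1 ground-truth labels, same shape

loss = criterion(logits, targets)
loss.backward()
print(loss.item())
```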
2. A Zhihu author's implementation
```python
# 2. A Zhihu author's implementation (multi-class focal loss)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, which is proposed in
    "Focal Loss for Dense Object Detection".

        Loss(x, class) = - \alpha (1 - softmax(x)[class])^gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        alpha (1D Tensor, Variable): the scalar factor for this criterion
        gamma (float, double): gamma > 0; reduces the relative loss for well-classified
            examples (p > .5), putting more focus on hard, misclassified examples
        size_average (bool): By default, the losses are averaged over observations for
            each minibatch. However, if size_average is set to False, the losses are
            instead summed for each minibatch.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=1)

        # One-hot mask of shape (N, C) marking the target class of each sample.
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)

        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        # Per-sample alpha, selected by the target class.
        alpha = self.alpha[ids.data.view(-1)]

        # p_t: predicted probability of the true class for each sample.
        probs = (P * class_mask).sum(1).view(-1, 1)
        log_p = probs.log()

        # Focal loss per sample: -alpha * (1 - p_t)^gamma * log(p_t)
        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
```
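A usage sketch for this multi-class version; the class count, batch size, and values are invented for illustration:

```python
# Usage sketch (illustrative values only).
num_classes = 5
criterion = FocalLoss(class_num=num_classes, gamma=2, size_average=True)

inputs = torch.randn(16, num_classes, requires_grad=True)  # raw logits, shape (N, C)
targets = torch.randint(0, num_classes, (16,))             # integer class labels, shape (N,)

loss = criterion(inputs, targets)
loss.backward()
print(loss.item())
```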
Summary
The above covers the theory of Focal Loss and its code implementation; hopefully it helps with any problems you have run into.