Focal Loss详解及其pytorch实现
文章目录
- Focal Loss详解及其pytorch实现
- 引言
- 二分类与多分类的交叉熵损失函数
- 二分类交叉熵损失
- 多分类交叉熵损失
- Focal Loss基础概念
- 关键点理解
- 什么是难分类样本和易分类样本?
- 超参数 γ \gamma γ 的作用
- 超参数 α \alpha α 的作用
- 超参数 α \alpha α 的详细解释
- 实例计算
- 负样本实例
- 多分类实例
- PyTorch实现
- 二分类Focal Loss
- 多分类Focal Loss
- 结论
- 参考文献
引言
Focal Loss是由何恺明等人在2017年的论文《Focal Loss for Dense Object Detection》中提出的。它是一种专门为解决目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。本文将详细介绍Focal Loss的基本概念、二分类和多分类的交叉熵损失函数,以及如何设置Focal Loss中的关键参数,并提供PyTorch的实现代码。
二分类与多分类的交叉熵损失函数
二分类交叉熵损失
在二分类的任务中,一般使用Sigmoid作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat{y} y^,二分类非正即负,所以样本为负的概率值为 1 − y ^ 1-\hat{y} 1−y^。二分类交叉熵损失的计算公式为:
CEL = − y ⋅ log ( y ^ ) − ( 1 − y ) ⋅ log ( 1 − y ^ ) \text{CEL} = -y \cdot \log(\hat{y}) - (1-y) \cdot \log(1-\hat{y}) CEL=−y⋅log(y^)−(1−y)⋅log(1−y^)
其中 y y y 是实际标签,正样本为1,负样本为0, y ^ \hat{y} y^ 是Sigmoid激活函数的输出值。
多分类交叉熵损失
在多分类的情况下,一般使用Softmax作为最后的激活函数,输出有多个值,对应每个分类的概率值,且这些值之和为1。多分类交叉熵损失的计算公式为:
CEL = − ∑ c = 1 C y c ⋅ log ( y ^ c ) = − log ( y ^ c ) \text{CEL} = -\sum_{c=1}^{C} y_c \cdot \log(\hat{y}_c) = -\log(\hat{y}_c) CEL=−c=1∑Cyc⋅log(y^c)=−log(y^c)
其中 y ^ c \hat{y}_c y^c 表示Softmax激活函数输出结果中第 c c c 类的对应的值, C C C 是类别的总数。
Focal Loss基础概念
关键点理解
要真正理解Focal Loss,有三个关键点需要明确:
- 二分类(Sigmoid)和多分类(Softmax)的交叉熵损失表达形式的区别。
- 理解难分类样本与易分类样本。
- Focal Loss中的超参数 α \alpha α 和 γ \gamma γ 的作用。
什么是难分类样本和易分类样本?
- 易分类样本:模型预测正确的概率较高,即 y ^ t \hat{y}_t y^t 较大(通常 y ^ t > 0.5 \hat{y}_t > 0.5 y^t>0.5)。
- 难分类样本:模型预测正确的概率较低,即 y ^ t \hat{y}_t y^t 较小(通常 y ^ t < 0.5 \hat{y}_t < 0.5 y^t<0.5)。
其中 y ^ t \hat{y}_t y^t 定义为:
y ^ t = { y ^ if y = 1 1 − y ^ otherwise \hat{y}_t = \begin{cases} \hat{y} & \text{if } y = 1 \\ 1 - \hat{y} & \text{otherwise} \end{cases} y^t={y^1−y^if y=1otherwise
超参数 γ \gamma γ 的作用
超参数 γ \gamma γ 控制了难分类样本和易分类样本在损失函数中的比重。Focal Loss相对于原始的交叉熵损失增加了 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1−y^t)γ 这一项,对原始交叉熵损失进行了衰减。当 γ \gamma γ 增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。
超参数 α \alpha α 的作用
超参数 α \alpha α 用于调整正负样本之间的权重。在二分类中, α \alpha α 的值反映了样本数量较少的类的权重。通常情况下,正样本数量较少(在本文中正样本代表数量少的样本),因此 α \alpha α 值反映了正样本的权重。随着 γ \gamma γ 的增加, α \alpha α 应该稍微降低。这是因为:
- 低 α \alpha α 对应高 γ \gamma γ。负样本通常容易被正确分类,其权重已经被 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1−y^t)γ 显著降低,因此无需给正样本再增加额外过大的权重 α \alpha α。
- 在Focal Loss中, γ \gamma γ 占主要地位,它确保了模型更加关注那些难以正确分类的样本。
- 当处理负样本时, α \alpha α 的值通常为 1 − α 1 - \alpha 1−α,其中 α \alpha α 为正样本的权重。
超参数 α \alpha α 的详细解释
在Focal Loss中, α \alpha α 的作用是调整正负样本之间的权重。理论上,数量越少的类应该具有更大的权重。然而,在原论文作者的实验中,当 α = 0.25 \alpha = 0.25 α=0.25 和 γ = 2 \gamma = 2 γ=2 时,模型表现最好。这引发了一个问题:为什么正样本的权重( α = 0.25 \alpha = 0.25 α=0.25)反而比负样本的权重( 1 − α = 0.75 1 - \alpha = 0.75 1−α=0.75)要低,尤其是当负样本的数量远远多于正样本时?
这是因为Focal Loss的设计初衷是为了减少易分类样本的贡献,让模型更加关注难分类样本。随着 γ \gamma γ 的增加,难分类样本的权重实际上已经被显著提高了。此外,由于负样本通常更容易被正确分类,其权重已经被 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1−y^t)γ 大幅降低,因此不需要再额外增加正样本的权重。这意味着,在Focal Loss中, γ \gamma γ 的作用更为关键,而 α \alpha α 的作用则相对次要。
实例计算
假设我们有一个正样本,模型预测的概率为0.8,取 γ = 2 \gamma = 2 γ=2。
-
计算 y ^ t \hat{y}_t y^t:
y ^ t = y ^ = 0.8 \hat{y}_t = \hat{y} = 0.8 y^t=y^=0.8 -
计算Focal Loss:
FL ( y ^ t ) = − α t ⋅ ( 1 − 0.8 ) 2 ⋅ log ( 0.8 ) \text{FL}(\hat{y}_t) = -\alpha_t \cdot (1 - 0.8)^2 \cdot \log(0.8) FL(y^t)=−αt⋅(1−0.8)2⋅log(0.8)
若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 0.25 \alpha_t = 0.25 αt=0.25,因此:
FL ( y ^ t ) = − 0.25 ⋅ ( 0.2 ) 2 ⋅ log ( 0.8 ) ≈ − 0.25 ⋅ 0.04 ⋅ ( − 0.22314 ) ≈ 0.00223 \text{FL}(\hat{y}_t) = -0.25 \cdot (0.2)^2 \cdot \log(0.8) \approx -0.25 \cdot 0.04 \cdot (-0.22314) \approx 0.00223 FL(y^t)=−0.25⋅(0.2)2⋅log(0.8)≈−0.25⋅0.04⋅(−0.22314)≈0.00223
负样本实例
假设我们有一个负样本,模型预测的概率为0.2,取 γ = 2 \gamma = 2 γ=2。
-
计算 y ^ t \hat{y}_t y^t:
y ^ t = 1 − y ^ = 1 − 0.2 = 0.8 \hat{y}_t = 1 - \hat{y} = 1 - 0.2 = 0.8 y^t=1−y^=1−0.2=0.8 -
计算Focal Loss:
FL ( y ^ t ) = − α t ⋅ ( 1 − 0.8 ) 2 ⋅ log ( 0.8 ) \text{FL}(\hat{y}_t) = -\alpha_t \cdot (1 - 0.8)^2 \cdot \log(0.8) FL(y^t)=−αt⋅(1−0.8)2⋅log(0.8)
若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 1 − 0.25 = 0.75 \alpha_t = 1 - 0.25 = 0.75 αt=1−0.25=0.75,因此:
FL ( y ^ t ) = − 0.75 ⋅ ( 0.2 ) 2 ⋅ log ( 0.8 ) ≈ − 0.75 ⋅ 0.04 ⋅ ( − 0.22314 ) ≈ 0.00669 \text{FL}(\hat{y}_t) = -0.75 \cdot (0.2)^2 \cdot \log(0.8) \approx -0.75 \cdot 0.04 \cdot (-0.22314) \approx 0.00669 FL(y^t)=−0.75⋅(0.2)2⋅log(0.8)≈−0.75⋅0.04⋅(−0.22314)≈0.00669
多分类实例
假设我们有三个类别(猫、狗、兔子),模型预测的概率分别为 [ 0.2 , 0.5 , 0.3 ] [0.2, 0.5, 0.3] [0.2,0.5,0.3],实际标签是狗(one-hot编码为[0, 1, 0]),取 γ = 2 \gamma = 2 γ=2。
-
计算 y ^ c \hat{y}_c y^c:
y ^ c = y ^ 2 = 0.5 \hat{y}_c = \hat{y}_2 = 0.5 y^c=y^2=0.5 -
计算Focal Loss:
FL ( y ^ 2 ) = − α 2 ⋅ ( 1 − 0.5 ) 2 ⋅ log ( 0.5 ) \text{FL}(\hat{y}_2) = -\alpha_2 \cdot (1 - 0.5)^2 \cdot \log(0.5) FL(y^2)=−α2⋅(1−0.5)2⋅log(0.5)
若取 α 2 = 0.25 \alpha_2 = 0.25 α2=0.25,则:
FL ( y ^ 2 ) = − 0.25 ⋅ ( 0.5 ) 2 ⋅ log ( 0.5 ) ≈ − 0.25 ⋅ 0.25 ⋅ ( − 0.69315 ) ≈ 0.04332 \text{FL}(\hat{y}_2) = -0.25 \cdot (0.5)^2 \cdot \log(0.5) \approx -0.25 \cdot 0.25 \cdot (-0.69315) \approx 0.04332 FL(y^2)=−0.25⋅(0.5)2⋅log(0.5)≈−0.25⋅0.25⋅(−0.69315)≈0.04332
PyTorch实现
二分类Focal Loss
import torchclass FocalLossBinary(torch.nn.Module):"""二分类Focal Loss"""def __init__(self, alpha=0.25, gamma=2):super(FocalLossBinary, self).__init__()self.alpha = alphaself.gamma = gammadef forward(self, preds, labels):"""preds: sigmoid的输出结果labels: 标签"""eps = 1e-7loss_1 = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labelsloss_0 = -1 * (1 - self.alpha) * torch.pow(preds, self.gamma) * torch.log(1 - preds + eps) * (1 - labels)loss = loss_0 + loss_1return torch.mean(loss)
多分类Focal Loss
import torchclass FocalLossMultiClass(torch.nn.Module):def __init__(self, weight=None, gamma=2):super(FocalLossMultiClass, self).__init__()self.gamma = gammaself.weight = weightdef forward(self, preds, labels):"""preds: softmax输出结果labels: 真实值"""eps = 1e-7y_pred = preds.view((preds.size()[0], preds.size()[1], -1)) # B*C*H*W->B*C*(H*W)target = labels.view(y_pred.size()) # B*C*H*W->B*C*(H*W)ce = -1 * torch.log(y_pred + eps) * targetfloss = torch.pow((1 - y_pred), self.gamma) * ceif self.weight is not None:floss = torch.mul(floss, self.weight)floss = torch.sum(floss, dim=1)return torch.mean(floss)
结论
Focal Loss通过引入两个超参数 α \alpha α 和 γ \gamma γ,有效地解决了类别不平衡和难易样本不平衡的问题。通过调整这些超参数,可以使模型更加关注那些难以正确分类的样本,从而提高整体性能。在实际应用中,可以通过实验来确定最佳的 α \alpha α 和 γ \gamma γ 值。
参考文献
Focal Loss的理解以及在多分类任务上的使用(Pytorch) -
GHZhao_GIS_RS - CSDN