交叉熵是机器学习中用于衡量两个概率分布之间差异的一种损失函数,广泛用于分类问题(如逻辑回归、多分类神经网络)。交叉熵损失直接来源于信息论,特别适合概率建模问题。
1. 定义
交叉熵公式
给定真实分布 P 和预测分布 Q,交叉熵定义为:
- P(i):真实分布中类别 i 的概率。
- Q(i):预测分布中类别 i 的概率。
直观理解
- 当 Q(i) 越接近 P(i),交叉熵值越小。
- 如果预测 Q(i) 偏离真实分布 P(i),损失会快速增大。
2. 在分类问题中的应用
二分类问题
对于目标变量 ,预测类别为 1 的概率记为
,则交叉熵损失为:
- 直观解释:
- 如果
,只保留
,表示预测为正类的概率是否接近 1。
- 如果
,只保留
,表示预测为负类的概率是否接近 1。
- 如果
多分类问题
对于目标变量 ,预测概率分布为
,真实类别的独热编码为
,则交叉熵损失为:
:指示函数,表示样本 i 是否属于类别 k。
:模型预测样本 i 属于类别 k 的概率。
3. 与其他概念的关系
与负对数似然的关系
交叉熵损失是负对数似然(Negative Log-Likelihood, NLL)的具体形式,特别是对于分类问题:
- 二分类交叉熵 = 负对数似然损失。
- 多分类交叉熵 = Softmax + 负对数似然。
与信息论的关系
- 交叉熵来源于信息熵:
- 信息熵(Entropy):描述单一分布的不确定性。
- 交叉熵:衡量一个分布 P 如何被另一个分布 Q 表示。
- 信息熵(Entropy):描述单一分布的不确定性。
与 Kullback-Leibler 散度的关系
交叉熵和 KL 散度的关系为:
- H(P):真实分布 P 的信息熵。
:预测分布 Q 与真实分布 P 的差异。
. 性质与特点
优点
- 概率解释清晰:直接衡量预测分布与真实分布的差异。
- 数值敏感:能捕捉预测分布与目标分布之间的细微差别。
- 凸优化:在许多情况下,交叉熵损失是凸的,便于优化。
缺点
- 对错误预测敏感:如果预测概率接近 0,
会趋向无穷大。
- 数值问题:需对
进行裁剪,如
,以避免
错误。
5. 代码实现
二分类交叉熵损失
import numpy as npdef binary_cross_entropy(y_true, y_pred):y_pred = np.clip(y_pred, 1e-10, 1 - 1e-10) # 防止 log(0)return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))# 示例
y_true = np.array([1, 0, 1, 0])
y_pred = np.array([0.9, 0.1, 0.8, 0.4])
loss = binary_cross_entropy(y_true, y_pred)
print("Binary Cross-Entropy Loss:", loss)
输出结果
Binary Cross-Entropy Loss: 0.23617255159896325
多分类交叉熵损失
from sklearn.metrics import log_loss# 示例
y_true = [0, 2, 1, 2] # 真实类别
y_pred = [[0.9, 0.05, 0.05], # 预测概率分布[0.1, 0.1, 0.8],[0.2, 0.7, 0.1],[0.05, 0.1, 0.85]
]loss = log_loss(y_true, y_pred)
print("Multi-class Cross-Entropy Loss:", loss)
输出结果
Multi-class Cross-Entropy Loss: 0.2119244851021358
6. 交叉熵损失的应用
-
逻辑回归:
- 使用二分类交叉熵来优化模型参数。
-
神经网络:
- 输出层结合 Sigmoid(用于二分类)或 Softmax(用于多分类)激活,常配合交叉熵损失。
-
生成模型:
- 衡量生成分布与真实分布的差异。
-
概率分布拟合:
- 对预测分布进行监督,提升模型对类别概率的建模能力。
交叉熵作为一种衡量概率分布差异的工具,不仅在分类任务中表现优异,还为概率建模问题提供了坚实的理论基础。