CrossEntropyLoss中ignore_index的坑:全忽略时,结果是nan

📅 2026/7/1 11:22:56
CrossEntropyLoss中ignore_index的坑:全忽略时,结果是nan
PyTorch CrossEntropyLoss 中ignore_index的完全解析一、现象描述在使用torch.nn.CrossEntropyLoss(ignore_index0)时你可能会遇到一个奇怪的现象如果batch 中所有样本的标签都是0即全部被忽略损失输出为nan。如果只有部分样本的标签是0损失可以正常计算。例如下面的代码importtorchimporttorch.nnasnn criterionnn.CrossEntropyLoss(ignore_index0)x1torch.tensor([[1.2,2.1],[2.1,0.5]])y1torch.tensor([0,1])# 部分忽略loss1criterion(x1,y1)print(loss1)# 输出 tensor(1.7839)x2torch.tensor([[2.1,0.5]])y2torch.tensor([0])# 全部忽略loss2criterion(x2,y2)print(loss2)# 输出 tensor(nan)这是怎么回事下面从机制到原因逐步剖析。二、ignore_index是如何“忽略”的ignore_index的工作方式不是将损失设为 0而是完全跳过对应样本。对于每个样本模型输出一个 logits 向量同时有一个真实标签target。如果target ignore_index那么这个样本不参与损失计算不计算交叉熵值不参与梯度回传相当于从当前 batch 中被“删除”具体到损失计算过程默认reductionmeanLoss (有效样本的损失之和) / (有效样本的数量)分子只累加未被忽略样本的损失分母只计数未被忽略样本的数量被忽略的样本对分子和分母都没有贡献。三、案例拆解为什么loss1正常loss2是nan案例 1部分忽略 → 正常y1 [0, 1]样本 1标签 0被忽略 → 不计入分子和分母样本 2标签 1有效 → 计算损失L2分子 L2分母 1Loss L2 / 1 L2→ 有限值案例 2全部忽略 →nany2 [0]样本 1标签 0被忽略 → 不计入分子和分母分子 0没有任何有效损失被加总分母 0没有任何有效样本Loss 0 / 0→ 数学上未定义 → PyTorch 返回nan这就是nan的直接原因空集求平均。四、不同reduction模式下的行为对比CrossEntropyLoss支持三种reduction模式。当所有样本都被忽略时它们的表现不同reduction 模式所有样本被忽略时的返回值原因mean(默认)nan0 / 0未定义sum0.0有效损失之和为 0无需除法nonetensor([0., 0., ...])每个被忽略样本的位置输出0.0示例代码criterion_meantorch.nn.CrossEntropyLoss(ignore_index0,reductionmean)criterion_sumtorch.nn.CrossEntropyLoss(ignore_index0,reductionsum)criterion_nonetorch.nn.CrossEntropyLoss(ignore_index0,reductionnone)xtorch.tensor([[2.1,0.5]])ytorch.tensor([0])print(criterion_mean(x,y))# nanprint(criterion_sum(x,y))# 0.0print(criterion_none(x,y))# tensor([0.])五、总结与核心避坑点ignore_index的实质将被忽略样本从**分子损失和和分母样本计数**中同时剔除。nan的触发条件当且仅当batch 内所有样本的标签都等于ignore_index时分母为 0在mean模式下产生nan。不同reduction的安全性使用reductionsum或none可以避免nan。但sum会改变损失尺度影响梯度大小none需要手动聚合。设计合理性0/0未定义返回nan是数学上的正确行为提醒开发者处理空 batch 情况。六、实际应用建议在训练中尤其是处理类别极度不平衡的数据集或数据增强如随机裁剪导致全是背景无效区域时很容易出现某个 batch 全部被忽略的情况。推荐做法在训练循环中检查nan并跳过该 batchcriteriontorch.nn.CrossEntropyLoss(ignore_index0)optimizertorch.optim.Adam(model.parameters())forinputs,targetsindataloader:outputsmodel(inputs)losscriterion(outputs,targets)iftorch.isnan(loss):# 跳过这个 batch不更新模型continueoptimizer.zero_grad()loss.backward()optimizer.step()这样可以避免nan传播污染模型参数。替代方案如果你确定 batch 中偶尔会出现全忽略的情况且希望训练不中断可以临时改用reductionsum或none但要记得调整学习率或手动处理聚合。更好的办法在数据加载时确保每个 batch 至少包含一个非忽略样本例如通过采样策略。