llamafactory gradient_checkpointing 梯度检查点 通俗完整讲解

📅 2026/6/26 1:33:47
llamafactory gradient_checkpointing 梯度检查点 通俗完整讲解
llamafactory gradient_checkpointing 梯度检查点 通俗完整讲解1. 原生不开启时你现在 false 的状态模型走一遍前向传播会把所有中间激活值全部存在显存里。作用反向传播算梯度时直接拿这些中间值计算不用重新跑前向速度快。代价序列越长、batch越大中间激活占的显存爆炸很容易冲到99%。举个你场景的例子cutoff_len1024、bs3214B大模型中间激活张量体积非常大这就是你显存经常顶满的核心元凶。2. 开启 gradient_checkpointing: true 做了什么核心逻辑以少量重复计算换取大幅显存节省前向传播时不保存全部中间激活只存少量关键节点等到反向传播需要某一段中间数据时重新再跑一遍对应区间的前向计算现场算出激活全程不会一次性堆海量中间张量在显存里显存峰值直接下降 30%50%。和你梯度累积完全无关不管 gradient_accumulation_steps 是2还是4梯度检查点只管「前向激活要不要存」不影响梯度累加逻辑。3. 优缺点优点大幅压低显存峰值解决你 88%99% 冲高、偶尔OOM崩溃问题同样显卡下可以开更大 batch / 更长序列长度搭配 FlashAttention2 双重省显存MI300 192GB 体验提升明显。缺点每一步训练会多跑一小段前向计算训练速度大概慢 10%20%CPU算力、读写开销轻微上涨preprocessing_num_workers8完全能扛住。4. 适配你当前场景怎么选现状显存经常冲到99%波动极大→建议开启 true稳定性优先轻微降速完全可接受如果你后续调小batch、且全程显存稳定80以内不报警再切回 false 提速。5. 补充关键细节SFT训练专用use_cache: false必须搭配梯度检查点一起用你配置里已经开了正确训练阶段禁用KV缓存额外省一大块显存LoRA微调场景下开启梯度检查点不影响LoRA梯度更新只冻结基础模型部分重算LoRA训练效果无损失和 bf16 / FlashAttention2 兼容ROCm MI300无兼容性bug。极简总结不开存所有中间激活 → 显存占用高、跑的快开启丢掉大部分中间激活反向时临时重算 → 显存砍半、速度略慢专门解决你长文本1024序列导致的显存爆满问题。