从 14B 到 1.5B基于 Top-K Logits 的 LLM 知识蒸馏实战Qwen2.5 全流程把 14B 大模型的能力塞进1.5B 小模型里显存怎么压速度怎么提教师 logits 怎么存本文用一个完整可跑的项目带你打通 Qwen2.5 系列的 Logits 离线蒸馏全链路。一、为什么要做 Logits 蒸馏1.1 痛点大模型推理成本高是业界共识。Qwen2.5-14B 能力强但部署成本高Qwen2.5-1.5B 轻量但能力弱。知识蒸馏Knowledge Distillation就是让学生模型在教师模型的指导下训练逼近教师性能。1.2 为什么选 Logits 蒸馏而非响应蒸馏主流蒸馏分两派方法教师输出信息量训练速度Logits 蒸馏本项目每个位置全词表的概率分布极大快离线预计算响应蒸馏教师生成的文本答案有限慢需在线生成Logits 蒸馏的核心优势是暗知识Dark Knowledge——教师模型对每个 token 的概率分布包含丰富的类间关系信息。比如做数学题时“divide” 和 “multiply” 的概率对比硬标签标准答案给不了但 logits 能给。而且教师 logits 可以离线预计算并存储训练阶段无需再跑教师模型大幅节省训练时间。二、项目概览教师模型Qwen2.5-14B-Instruct-AWQ4-bit 量化学生模型Qwen2.5-1.5B-Instruct数据集MetaMathQA数学应用题方法Top-K Logits 离线蒸馏硬件2× RTX 3090 (24GB)为什么教师/学生选同系列Qwen2.5 全系列0.5B 到 72B使用完全相同的 tokenizer词表 152,064。教师 logits 的 token 索引和学生 logits 的 token 索引天然对齐零额外处理。三、核心概念温度、Top-K 与暗知识3.1 温度Temperature的作用标准 Softmax (T1): p_i exp(logit_i / 1) / Σ exp(logit_j / 1) 高温 Softmax (T3): p_i exp(logit_i / 3) / Σ exp(logit_j / 3)温度越高概率分布越平滑暴露出更多非顶部的概率信息。本项目用T3.0。3.2 Top-K 稀疏存储的空间魔法Qwen2.5 词表 152,064每个位置完整 logits 在 fp16 下占 304KB。对于 5000 条样本 × 2048 位置全量存储需要约 2.9 TB。只保留概率最高的K50个 token全量: [N, 2048, 152064] → ~2.9 TB Top-K: [N, 2048, 50] → ~1.4 GB 节省: 99.97%而教师模型 Top-50 token 通常覆盖了99% 的概率质量信息几乎无损。四、全流程详解整个 pipeline 分四个阶段MetaMathQA │ ▼ [阶段0] prepare_data.py 分词 → input_ids.pt [N, 2048] │ ▼ [阶段1] generate_logits.py 教师推理 Top-K → top_indices.npy │ ▼ [阶段2] train_student.py 学生蒸馏训练 → final_model/ │ ▼ [阶段3] eval_compare.py 推理对比阶段 0数据准备load_dataset(meta-math/MetaMathQA)→ shuffle(seed42).select(range(5000))→ tokenizer(queries,max_length2048,paddingmax_length)→ 保存为 input_ids.pt/attention_mask.pt关键点把分词结果缓存成.pt文件避免教师推理和学生训练两个阶段重复分词同时保证输入完全对齐。阶段 1教师 Logits 生成这是项目最有工程价值的一步forbatch_startinrange(0,N,batch_size):input_ids_batchinput_ids[batch_start:batch_end].cuda()outputsmodel(input_ids_batch,attention_maskattention_mask_batch)logitsoutputs.logits# [bs, 2048, 152064]# 提取 Top-Kvalues,indicestorch.topk(logits,k50)# [bs, 2048, 50]# 立即释放显存deloutputs,logits torch.cuda.empty_cache()# 写入 mmap 文件top_indices[batch_start:batch_end]indices.cpu().numpy().astype(np.int32)top_values[batch_start:batch_end]values.cpu().numpy().astype(np.float16)# 更新进度save_progress(batch_end)三个工程亮点numpy mmap 预分配用modew预分配全量空间每个样本在文件中有唯一物理偏移量支持精确断点续传。断点续传progress.json记录最后完成索引重启后从断点继续不破坏已有数据。即时显存释放每个 batch 后delempty_cache把显存占用压到最低。阶段 2学生蒸馏训练蒸馏损失公式Loss α · T² · KL(p_teacher || p_student) (1 - α) · CE(p_student, y_true) └─────────────────────────────────┘ └────────────────────────────┘ 蒸馏项软标签 硬标签项标准训练参数α 0.7,T 3.0。为什么乘 T²温度 T 平滑了 logits 分布。根据链式法则KL 损失对 student logits 的梯度会多一个1/T因子高温下梯度量级急剧减小。乘 T² 补偿这个衰减使 KL 损失梯度与标准 CE 在同一数量级。Top-K 上的高效 KL 计算# 提取学生在教师 Top-K 位置对应的 logits避免创建全词表张量student_topktorch.gather(student_logits,dim-1,indexteacher_top_indices)p_teachersoftmax(teacher_top_values/T,dim-1)log_p_studentlog_softmax(student_topk/T,dim-1)kl_loss(p_teacher*(p_teacher.log()-log_p_student)).sum(dim-1).mean()kl_losskl_loss*T*T# 梯度补偿用torch.gather只在 Top-K 维度计算避免创建 [B, S, 152064] 的全张量计算和显存都大幅减少。五、踩坑实录重点推荐阅读项目过程踩了两个非常典型的坑写出来给大家避雷。5.1 cuBLAS fp16 GEMM Bug现象教师模型前向推理在lm_head层崩溃RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling cublasGemmEx(... a, CUDA_R_16F, ... b, CUDA_R_16F, ...)根因分析PyTorch 2.10.0cu128 └── 捆绑 nvidia-cublas-cu1212.8.4.1有 bug └── cublasGemmEx ├── CUDA_R_16F (fp16) ❌ ├── CUDA_R_16BF (bf16) ❌ └── CUDA_R_32F (fp32) ✅注意 AWQ 的 Marlin 量化内核不受影响它有自有 int4→fp16 反量化路径不走 cuBLAS。所以教师模型的量化层正常但lm_head标准nn.Linear会触发 bug。临时方案把教师lm_head包装为 fp32 计算class_Float32Linear(torch.nn.Module):defforward(self,x):returnF.linear(x.float(),self.weight)显存多占 ~1.5GBRTX 3090 承受得住。但学生模型全是标准 Linear无法逐个包装。根治方案升级 cuBLAS 到修复版本pipinstallnvidia-cublas-cu1212.9.1.4 --force-reinstallpip 会警告torch 2.10.0cu128 requires nvidia-cublas-cu1212.8.4.1这只是依赖声明警告实际运行完全正常。5.2 单卡 OOM 优化学生模型训练时原始配置batch_size4, max_seq_len2048导致显存爆炸Logits 大小 4 × 2048 × 152064 × 2 bytes ≈ 2.5 GB反向传播时这个大张量必须缓存24GB 显存瞬间爆掉。解法Micro-batch 梯度累积# config.pyper_device_batch_size1# 从 4 降到 1gradient_accumulation_steps8# 从 2 升到 8# 等效全局 batch size 1 × 8 8保持不变工作机制[ 样本 1 ] → 前向反向 → 梯度累加暂存 [ 样本 2 ] → 前向反向 → 梯度累加暂存 ... [ 样本 8 ] → 前向反向 → 梯度累加 → 【optimizer.step() 更新参数】显存占用降到原来的25%等效 batch size 不变模型收敛效果无影响。5.3 AWQ 加载方式变了transformers 4.48 弃用了autoawq改用gptqmodel作为 AWQ 加载后端pipinstallgptqmodel不要再装 autoawq——它已官方废弃且会强制降级 transformers 到 4.47.1与新版本冲突。六、关键工程设计总结决策选择原因教师/学生同系列Qwen2.5 全家桶词表 100% 一致零额外处理存储格式numpy mmap训练时随机读取高效不需全部加载内存Top-K50稀疏存储节省 ~99.97% 存储空间KL 计算torch.gather只在 Top-K 维度计算避免 [B,S,152K] 全张量断点续传progress.jsonlogits 生成耗时长支持中断恢复mmap 懒加载__getitem__内_setup_mmap()解决多 worker DataLoader 的 pickle 问题梯度累积micro-batch1, accum8单卡跑得动大词表模型mmap 懒加载的小技巧np.memmap对象包含操作系统文件描述符不可被 pickle 序列化。DataLoadernum_workers 0时会把 Dataset 通过 pickle 分发给子进程直接初始化会报PicklingError。解决方案构造函数中设_top_indices None在__getitem__首次调用时才打开 mmap——每个 worker 进程独立打开自己的文件描述符。七、评估效果蒸馏成功的标志✅ 蒸馏模型回答更结构化有步骤、有公式、有\boxed{}答案✅减少幻觉不会凭空编造不合理假设✅ 回答风格更接近 14B 教师模型蒸馏不能解决的❌ 算术计算能力1.5B 参数量的固有限制❌ 教师模型本身不具备的知识八、一键运行# 环境conda create-nlogits-distillpython3.10conda activate logits-distill pipinstalltorch2.10.0torchvision0.25.0\--index-url https://download.pytorch.org/whl/cu128 pipinstall-rrequirements.txt# 国内镜像必要exportHF_ENDPOINThttps://hf-mirror.com# 全流程bashrun_all.sh# 或分步python prepare_data.py python generate_logits.py accelerate launch--num_processes2train_student.py python eval_compare.py写在最后Logits 蒸馏看似原理简单KL 散度 温度但工程实现里的细节决定了能不能跑通存储压缩Top-K 稀疏存储让 TB 级数据降到 GB 级显存优化mmap 梯度累积让单卡 3090 也能跑大词表模型环境陷阱cuBLAS bug、AWQ 后端变更等坑要提前规避希望这个项目能帮到你。如有问题欢迎评论区交流。