Mean-Teacher 均值教师自训练框架详解一、Mean-Teacher 核心原理详细描述1. 设计背景半监督学习场景少量标注数据 大量无标注数据。传统伪标签方法单模型预测噪声大、易过拟合无标注样本Mean-Teacher 提出双模型架构学生模型 教师模型用教师平滑预测作为软伪标签监督学生大幅降低伪标签噪声。2. 两大核心模型Student 学生网络可训练、带梯度更新输入加随机数据增强强增广同时接收标注/无标注样本损失包含监督损失 一致性正则损失。Teacher 教师网络无梯度、不直接训练权重由学生权重指数移动平均EMA缓慢更新θtα⋅θt(1−α)⋅θs\theta_t \alpha \cdot \theta_t (1-\alpha) \cdot \theta_sθtα⋅θt(1−α)⋅θsα\alphaαEMA衰减系数常用0.99~0.999训练前期可退火逐步提升。输入仅弱数据增强输出作为稳定、平滑的目标标签用来约束学生模型输出分布一致。3. 完整损失函数总损失 标注监督分类损失 无标注一致性损失LtotalLsupλ(t)⋅Lconsist \mathcal{L}_{total} \mathcal{L}_{sup} \lambda(t) \cdot \mathcal{L}_{consist}LtotalLsupλ(t)⋅Lconsist监督损失Lsup\mathcal{L}_{sup}Lsup仅作用有标签样本交叉熵分类损失一致性损失Lconsist\mathcal{L}_{consist}LconsistMSE/KL散度约束学生强增广输出分布 ≈ 教师弱增广输出分布λ(t)\lambda(t)λ(t)一致性权重退火系数训练前期权重小后期放大正则约束避免前期噪声干扰。4. 训练流程步骤初始化学生网络、教师网络复制学生初始权重关闭教师梯度每迭代取一批混合数据有标签无标签样本分支处理标注样本弱增强送入学生计算分类交叉熵无标注样本强增强输入学生弱增强输入教师前向传播学生输出带噪声强增广预测教师输出停止梯度平滑稳定预测伪标签目标计算一致性MSE损失叠加监督损失反向传播只更新学生网络EMA更新教师权重无梯度迭代至收敛推理只用教师模型泛化更好。5. 关键创新点总结EMA教师权重缓慢平滑权重预测更稳定缓解伪标签漂移强弱双增强分离学生强增广提升鲁棒性教师弱增广保证目标可靠一致性正则化利用无标注数据约束模型输出不变性损失权重退火解决训练初期模型预测不可靠的问题。二、PyTorch 代码完整代码MNIST半监督任务少量标注大量无标注importnumpyasnpimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFfromtorch.utils.dataimportDataset,DataLoader,Subsetfromtorchvisionimportdatasets,transformsfromtqdmimporttqdm# 超参数配置 DEVICEtorch.device(cudaiftorch.cuda.is_available()elsecpu)EPOCHS50BATCH_SIZE128LR0.002EMA_ALPHA0.99# EMA衰减系数LAMBDA_MAX10.0# 一致性损失最大权重ANNEAL_STEPSEPOCHS*500# 退火总步数NUM_LABELED1000# 仅使用1000张标注MNIST其余为无标注# 1. 简单CNN backbone学生/教师共用 classConvNet(nn.Module):def__init__(self):super().__init__()self.conv1nn.Conv2d(1,32,3,padding1)self.conv2nn.Conv2d(32,64,3,padding1)self.poolnn.MaxPool2d(2,2)self.fc1nn.Linear(64*7*7,128)self.fc2nn.Linear(128,10)defforward(self,x):xself.pool(F.relu(self.conv1(x)))xself.pool(F.relu(self.conv2(x)))xtorch.flatten(x,1)xF.relu(self.fc1(x))logitsself.fc2(x)returnlogits# 2. EMA 教师更新工具函数 defupdate_teacher(student_model,teacher_model,alpha):fors_param,t_paraminzip(student_model.parameters(),teacher_model.parameters()):t_param.dataalpha*t_param.data(1.0-alpha)*s_param.data# 计算一致性损失权重退火defget_consistency_weight(current_step):# sigmoid退火0~LAMBDA_MAXrampupnp.exp(-5.0*(1.0-current_step/ANNEAL_STEPS)**2)returnLAMBDA_MAX*rampup# 3. 数据增强弱增强 强增强 weak_augtransforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])strong_augtransforms.Compose([transforms.RandomHorizontalFlip(p0.5),transforms.RandomAffine(degrees10,translate(0.1,0.1)),transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])# 混合数据集区分标注/无标注样本classSemiMNIST(Dataset):def__init__(self,full_dataset,labeled_mask):self.datafull_dataset self.labeled_masklabeled_mask# True有标签, False无标签def__len__(self):returnlen(self.data)def__getitem__(self,idx):img,labelself.data[idx]is_labeledself.labeled_mask[idx]ifis_labeled:img_weakweak_aug(img)img_strongstrong_aug(img)returnimg_weak,img_strong,label,is_labeledelse:img_weak_tweak_aug(img)# 教师输入弱增强img_strong_sstrong_aug(img)# 学生输入强增强returnimg_strong_s,img_weak_t,-1,is_labeled# 4. 构建半监督MNIST数据集 defbuild_semi_mnist():train_fulldatasets.MNIST(root./data,trainTrue,downloadTrue,transformtransforms.ToTensor())total_trainlen(train_full)# 随机划分标注/无标注all_indicesnp.arange(total_train)np.random.shuffle(all_indices)labeled_idxall_indices[:NUM_LABELED]labeled_masknp.zeros(total_train,dtypebool)labeled_mask[labeled_idx]Truesemi_trainSemiMNIST(train_full,labeled_mask)test_setdatasets.MNIST(root./data,trainFalse,transformweak_aug)train_loaderDataLoader(semi_train,batch_sizeBATCH_SIZE,shuffleTrue,num_workers0)test_loaderDataLoader(test_set,batch_sizeBATCH_SIZE,shuffleFalse,num_workers0)returntrain_loader,test_loader# 5. 训练主逻辑 deftrain_mean_teacher():# 初始化双网络studentConvNet().to(DEVICE)teacherConvNet().to(DEVICE)# 教师初始权重复制学生冻结梯度fort_paraminteacher.parameters():t_param.requires_gradFalse# 优化器仅更新学生opttorch.optim.Adam(student.parameters(),lrLR)train_loader,test_loaderbuild_semi_mnist()global_step0forepochinrange(EPOCHS):student.train()total_loss_epoch0.0sup_loss_epoch0.0cons_loss_epoch0.0pbartqdm(train_loader,descfEpoch{epoch1}/{EPOCHS})forbatchinpbar:opt.zero_grad()# 拆分批次数据x1,x2,labels,is_labeledbatch x1,x2,labelsx1.to(DEVICE),x2.to(DEVICE),labels.to(DEVICE)batch_sizex1.shape[0]# 前向传播学生网络全部样本强增广输入学生student_logitsstudent(x1)withtorch.no_grad():# 教师仅弱增广输入停止梯度teacher_logitsteacher(x2)teacher_probsF.softmax(teacher_logits,dim1)# 1. 监督损失仅标注样本sup_maskis_labeled.to(DEVICE)sup_loss0.0iftorch.sum(sup_mask)0:sup_logitsstudent_logits[sup_mask]sup_labelslabels[sup_mask]sup_lossF.cross_entropy(sup_logits,sup_labels)# 2. 一致性损失全部样本标注无标注都约束分布一致student_probsF.softmax(student_logits,dim1)cons_lossF.mse_loss(student_probs,teacher_probs.detach())# 3. 加权总损失cons_weightget_consistency_weight(global_step)total_losssup_losscons_weight*cons_loss# 反向传播更新学生total_loss.backward()opt.step()# EMA平滑更新教师权重update_teacher(student,teacher,EMA_ALPHA)# 统计损失total_loss_epochtotal_loss.item()sup_loss_epochsup_loss.item()cons_loss_epochcons_loss.item()*cons_weight global_step1pbar.set_postfix({total_loss:f{total_loss.item():.4f},sup_loss:f{sup_loss.item():.4f},cons_loss:f{cons_weight*cons_loss.item():.4f},cons_w:f{cons_weight:.2f}})# 打印epoch平均损失avg_totaltotal_loss_epoch/len(train_loader)avg_supsup_loss_epoch/len(train_loader)avg_conscons_loss_epoch/len(train_loader)print(f\n[Epoch{epoch1}] Avg Loss: total{avg_total:.4f}, sup{avg_sup:.4f}, cons{avg_cons:.4f})# 测试阶段使用教师模型评估泛化更强teacher.eval()correct0total0withtorch.no_grad():forimg,labintest_loader:img,labimg.to(DEVICE),lab.to(DEVICE)logitsteacher(img)predtorch.argmax(logits,dim1)correct(predlab).sum().item()totallab.size(0)acc100*correct/totalprint(fTest Acc (Teacher Model):{acc:.2f}%\n)if__name____main__:train_mean_teacher()三、代码模块逐段解释1. 网络模块 ConvNet轻量2层CNN分类网络学生、教师完全同结构权重独立初始化后复制。2. EMA教师更新update_teacher仅在学生参数更新后执行无梯度参与教师权重缓慢跟随学生平滑移动避免单步梯度剧烈波动带来的预测噪声。3. 一致性权重退火get_consistency_weight训练前期一致性权重趋近0优先学习标注数据训练中后期权重上升利用大量无标注数据做一致性正则防止前期模型预测不可靠带来错误监督。4. 双增强策略弱增强轻微归一化给教师输入保证目标标签稳定强增强翻转、仿射变换给学生输入迫使模型对图像扰动保持输出一致。5. 数据集 SemiMNIST自动区分标注/无标注样本对两类样本分别返回对应强弱增强图像统一送入训练循环。6. 损失计算细节监督损失仅在有标签样本上计算交叉熵一致性MSE损失对全部样本生效有标签样本也会额外加一致性约束进一步提升鲁棒性教师输出全程torch.no_grad()不产生梯度仅作为固定目标。7. 推理规则测试时只用教师模型EMA平滑后的权重泛化性能显著优于实时更新的学生模型是Mean-Teacher标准推理方案。四、调优关键技巧EMA α分类任务0.99~0.999数值越大教师更新越慢、预测越平滑一致性权重LAMBDA_MAX图像分类常用5~20过小无正则效果过大训练震荡数据增强差距学生增强越强一致性约束收益越高训练策略前期降低学习率、缓慢提升一致性权重防止模型崩溃损失替换可将MSE替换KL散度对分类概率分布约束效果接近。五、运行效果说明数据集MNIST仅1000张标注其余59000张无标注基线只用1000标注无自训练测试精度约75%~82%Mean-Teacher训练后教师模型测试精度可达93%充分验证半监督一致性正则收益。