别再死磕CNN了!手把手教你用PyTorch从零搭建ViT模型(附完整代码)

📅 2026/7/1 9:12:30
别再死磕CNN了!手把手教你用PyTorch从零搭建ViT模型(附完整代码)
从零构建ViT模型PyTorch实战指南与代码解析在计算机视觉领域Transformer架构正掀起一场革命。2020年Google提出的Vision TransformerViT打破了CNN在图像处理领域的长期垄断证明了纯Transformer架构在视觉任务中的强大潜力。本文将带你从零开始用PyTorch完整实现一个ViT模型避开理论复述直击代码实现中的关键细节与实战技巧。1. 环境准备与数据加载1.1 安装必要依赖首先确保你的Python环境已安装PyTorch。推荐使用conda创建虚拟环境conda create -n vit_env python3.8 conda activate vit_env pip install torch torchvision torchaudio pip install numpy matplotlib tqdm1.2 数据预处理策略ViT对输入图像有特定要求——通常为224x224分辨率。我们使用torchvision的预处理管道from torchvision import transforms train_transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])注意ImageNet的均值和标准差是ViT论文使用的默认值即使你使用其他数据集也建议保留这些值2. Patch Embedding实现技巧2.1 卷积实现方案ViT的核心创新是将图像分割为固定大小的patch。在代码中这可以通过一个巧妙设计的卷积层实现import torch.nn as nn class PatchEmbedding(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim768): super().__init__() self.proj nn.Conv2d(in_chans, embed_dim, kernel_sizepatch_size, stridepatch_size) def forward(self, x): x self.proj(x) # [B, C, H, W] - [B, E, H/P, W/P] x x.flatten(2) # [B, E, H/P, W/P] - [B, E, N] x x.transpose(1, 2) # [B, E, N] - [B, N, E] return x关键参数说明参数值作用img_size224输入图像尺寸patch_size16每个patch的像素数embed_dim768每个patch的嵌入维度2.2 Class Token与位置编码ViT需要添加两个特殊元素class ViTEmbeddings(nn.Module): def __init__(self, config): super().__init__() self.patch_embeddings PatchEmbedding() self.cls_token nn.Parameter(torch.zeros(1, 1, config.hidden_size)) self.position_embeddings nn.Parameter( torch.zeros(1, num_patches 1, config.hidden_size)) def forward(self, x): batch_size x.shape[0] embeddings self.patch_embeddings(x) cls_tokens self.cls_token.expand(batch_size, -1, -1) embeddings torch.cat((cls_tokens, embeddings), dim1) embeddings embeddings self.position_embeddings return embeddings3. Transformer Encoder构建3.1 多头注意力实现ViT的核心组件是Transformer Encoder。先实现多头注意力机制class MultiHeadAttention(nn.Module): def __init__(self, config): super().__init__() self.num_heads config.num_heads self.attention_head_size config.hidden_size // config.num_heads self.query nn.Linear(config.hidden_size, config.hidden_size) self.key nn.Linear(config.hidden_size, config.hidden_size) self.value nn.Linear(config.hidden_size, config.hidden_size) def forward(self, x): q self.query(x) k self.key(x) v self.value(x) # 分割为多头 q q.view(q.size(0), -1, self.num_heads, self.attention_head_size) k k.view(k.size(0), -1, self.num_heads, self.attention_head_size) v v.view(v.size(0), -1, self.num_heads, self.attention_head_size) # 注意力得分计算 attention_scores torch.matmul(q, k.transpose(-1, -2)) attention_scores attention_scores / math.sqrt(self.attention_head_size) attention_probs nn.Softmax(dim-1)(attention_scores) # 上下文向量计算 context torch.matmul(attention_probs, v) context context.permute(0, 2, 1, 3).contiguous() new_context_shape context.size()[:-2] (self.num_heads * self.attention_head_size,) context context.view(*new_context_shape) return context3.2 完整Encoder Block结合LayerNorm和MLP构建完整Encoderclass TransformerBlock(nn.Module): def __init__(self, config): super().__init__() self.attention MultiHeadAttention(config) self.intermediate nn.Sequential( nn.Linear(config.hidden_size, config.intermediate_size), nn.GELU(), nn.Linear(config.intermediate_size, config.hidden_size) ) self.layernorm1 nn.LayerNorm(config.hidden_size) self.layernorm2 nn.LayerNorm(config.hidden_size) def forward(self, x): # 自注意力 attention_output self.attention(x) x x attention_output x self.layernorm1(x) # 前馈网络 intermediate_output self.intermediate(x) x x intermediate_output x self.layernorm2(x) return x4. 模型训练与调优4.1 学习率策略ViT训练对学习率敏感推荐使用余弦退火from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR optimizer AdamW(model.parameters(), lr3e-4, weight_decay0.01) scheduler CosineAnnealingLR(optimizer, T_maxnum_epochs)4.2 混合精度训练为加速训练并减少显存占用建议启用混合精度from torch.cuda.amp import GradScaler, autocast scaler GradScaler() for epoch in range(num_epochs): for images, labels in train_loader: optimizer.zero_grad() with autocast(): outputs model(images) loss criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()4.3 常见问题排查问题1训练初期loss不下降检查学习率是否合适验证数据预处理是否正确确认模型参数初始化问题2验证集准确率波动大增加batch size尝试更强的正则化检查学习率调度策略5. 模型评估与可视化5.1 注意力图可视化理解ViT如何看图像def visualize_attention(model, image): model.eval() with torch.no_grad(): outputs model(image, output_attentionsTrue) attentions outputs.attentions[-1] # 最后一层注意力 avg_attention attentions.mean(dim1)[:, 0, 1:] # class token对其他patch的注意力 # 重塑为2D网格 grid_size int(math.sqrt(avg_attention.shape[-1])) attention_map avg_attention.reshape(-1, grid_size, grid_size) # 叠加到原图 plt.imshow(image) plt.imshow(attention_map, cmaphot, alpha0.5)5.2 不同配置对比ViT性能与配置密切相关配置项BaseLargeHuge层数122432隐藏层维度76810241280MLP维度307240965120头数121616参数量86M307M632M6. 进阶技巧与优化6.1 知识蒸馏小模型可通过蒸馏从大ViT学习class DistillationLoss(nn.Module): def __init__(self, alpha0.5): super().__init__() self.alpha alpha self.ce_loss nn.CrossEntropyLoss() self.kl_loss nn.KLDivLoss(reductionbatchmean) def forward(self, student_logits, teacher_logits, labels): ce_loss self.ce_loss(student_logits, labels) kl_loss self.kl_loss( F.log_softmax(student_logits/T, dim-1), F.softmax(teacher_logits/T, dim-1) ) * (T**2) return self.alpha * ce_loss (1 - self.alpha) * kl_loss6.2 混合架构CNN与ViT的混合方案往往能取得更好效果class HybridViT(nn.Module): def __init__(self): super().__init__() self.cnn_backbone torchvision.models.resnet50(pretrainedTrue) self.vit VisionTransformer( hidden_size768, num_hidden_layers12, num_attention_heads12 ) def forward(self, x): features self.cnn_backbone(x) return self.vit(features)在实际项目中ViT的实现需要根据具体任务调整。一个完整的ViT模型通常需要数百行代码但核心思想就是这些关键组件的组合。建议从Base配置开始逐步扩展到更复杂的变体。