016、HAT混合注意力:通道-空间双注意力与Transformer的协同设计

📅 2026/7/1 14:33:04
016、HAT混合注意力:通道-空间双注意力与Transformer的协同设计
016、HAT混合注意力通道-空间双注意力与Transformer的协同设计从一次“模型不收敛”的深夜调试说起去年年底我在处理一个遥感图像超分任务时遇到了一个让人抓狂的问题用SwinIR训练了三天PSNR死活卡在29.5dB上不去。检查了数据增强、学习率调度、甚至把初始化方式换成了Kaiming结果纹丝不动。后来在TensorBoard里盯着特征图可视化看了两个小时发现模型在纹理密集区域比如建筑物边缘和均匀区域比如水面的响应几乎一样——说白了注意力机制根本没学会“该看哪里”。这个现象让我意识到Transformer虽然擅长建模长距离依赖但它在局部细节的“选择性关注”上其实不如传统的通道注意力比如SENet和空间注意力比如CBAM来得直接。于是我开始思考能不能把这两套思路拧在一起让模型既拥有Transformer的全局视野又保留CNN注意力的局部锐利这就是HATHybrid Attention Transformer的出发点。它不是简单的“加个模块”而是重新设计了注意力机制的协作方式。通道注意力别小看这个“老古董”很多人觉得通道注意力已经过时了但在超分任务里它依然是“性价比之王”。HAT里的通道注意力模块本质上是对SENet的改进——但有个关键细节别在通道压缩时用ReLU。classChannelAttention(nn.Module):def__init__(self,dim,reduction16):super().__init__()# 这里踩过坑用ReLU会导致高频信息丢失换成GELU会好很多self.fcnn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(dim,dim//reduction,1,biasFalse),nn.GELU(),# 别这样写nn.ReLU(inplaceTrue)nn.Conv2d(dim//reduction,dim,1,biasFalse),nn.Sigmoid())defforward(self,x):# 注意这里要保留原始特征图的维度别直接乘scaleself.fc(x)returnx*scale为什么不用ReLU因为超分任务需要保留精细的纹理信息ReLU的硬零截断会抹掉弱响应但重要的通道。GELU的软饱和特性让模型能保留更多“犹豫不决”的特征——这在重建高频细节时特别重要。空间注意力别做成“全局平均”空间注意力最容易犯的错误是做成全局平均池化。HAT的做法是用可变形卷积替代固定感受野。classSpatialAttention(nn.Module):def__init__(self,dim):super().__init__()# 这里用3x3可变形卷积别用普通卷积self.offset_convnn.Conv2d(dim,18,3,padding1)# 9个点的偏移量self.mod_convnn.Conv2d(dim,9,3,padding1)# 9个点的调制权重self.dcnDeformConv2d(dim,dim,3,padding1)self.sigmoidnn.Sigmoid()defforward(self,x):offsetself.offset_conv(x)modtorch.sigmoid(self.mod_conv(x))# 别这样写直接用固定网格采样attnself.dcn(x,offset,mod)returnx*self.sigmoid(attn)可变形卷积让空间注意力能自适应地聚焦到纹理边缘而不是均匀地扫描整个区域。这在处理遥感图像里的不规则建筑物、医学图像里的细胞边界时效果立竿见影。Transformer部分别照搬Swin很多人直接把Swin Transformer搬过来用但HAT做了两个关键改动第一窗口划分策略。标准Swin用固定7x7窗口但HAT根据特征图分辨率动态调整窗口大小。低分辨率层用大窗口比如12x12高分辨率层用小窗口比如4x4。这样既保证了全局感受野又避免了高分辨率下的计算爆炸。第二相对位置编码的初始化。这里有个坑直接用零初始化会导致训练初期梯度爆炸。HAT的做法是用截断正态分布初始化标准差设为0.02。classHATTransformerBlock(nn.Module):def__init__(self,dim,window_size7):super().__init__()self.window_sizewindow_size# 相对位置编码别用零初始化self.relative_position_biasnn.Parameter(torch.zeros((2*window_size-1)*(2*window_size-1),num_heads))# 这里踩过坑用nn.init.trunc_normal_替代nn.init.zeros_nn.init.trunc_normal_(self.relative_position_bias,std0.02)双注意力协同不是简单的相加HAT的核心创新在于通道-空间-Transformer的渐进式融合。具体来说第一阶段通道注意力先做“粗选”筛选出信息量大的通道第二阶段空间注意力做“精定位”在选中的通道上聚焦关键区域第三阶段Transformer做“全局整合”把局部特征拼成完整的高频信息这个顺序不能乱。我试过把空间注意力放在通道注意力前面结果PSNR掉了0.3dB——因为空间注意力在没有通道筛选的情况下会把计算资源浪费在噪声区域。classHATBlock(nn.Module):def__init__(self,dim):super().__init__()self.channel_attnChannelAttention(dim)self.spatial_attnSpatialAttention(dim)self.transformerHATTransformerBlock(dim)# 注意这里用残差连接别直接串联self.fusionnn.Conv2d(dim*3,dim,1)defforward(self,x):# 别这样写x self.transformer(self.spatial_attn(self.channel_attn(x)))cself.channel_attn(x)sself.spatial_attn(x)tself.transformer(x)# 三路并行后融合保留各自特性returnself.fusion(torch.cat([c,s,t],dim1))训练技巧别让注意力“偷懒”HAT在训练时有个常见问题通道注意力会逐渐“躺平”所有通道的权重趋近于1。解决办法是在损失函数里加入注意力正则项defhat_loss(pred,target,attn_weights):mse_lossnn.L1Loss()(pred,target)# 让注意力权重保持多样性别全变成1entropy_loss-torch.mean(attn_weights*torch.log(attn_weights1e-8))returnmse_loss0.01*entropy_loss这个技巧让模型在训练过程中始终保持对通道的“选择性关注”而不是偷懒地全盘接受。个人经验什么时候该用HAT如果你在以下场景中HAT值得一试纹理密集且不规则比如遥感图像、布料纹理、动物皮毛需要同时保留全局结构和局部细节比如人脸超分既要五官位置准确又要皮肤纹理真实训练数据量中等10万-100万张HAT的参数量比SwinIR大15%左右但收敛速度更快但如果你处理的是纯文本图像比如扫描文档或者医学CT图像结构相对固定标准SwinIR可能更合适——HAT的额外注意力模块反而会引入不必要的复杂度。最后说个玄学HAT在batch size16时效果最好batch size太大比如64会导致注意力分布过于平滑。这可能是小batch size带来的随机性有助于保持注意力的“锐度”。