015、SwinIR窗口注意力:Swin Transformer在超分中的局部-全局建模

📅 2026/7/1 7:56:15
015、SwinIR窗口注意力:Swin Transformer在超分中的局部-全局建模
015、SwinIR窗口注意力Swin Transformer在超分中的局部-全局建模去年调一个视频超分模型跑了一周发现PSNR死活上不去比baseline还低0.3dB。排查到最后发现是窗口注意力里的mask没对齐——图像边界处窗口划分时多了一个像素的偏移导致边缘区域的信息完全错乱了。这种坑SwinIR的论文里不会写但实际部署时能让你debug到怀疑人生。今天这篇笔记咱们就聊聊SwinIR里那个让超分效果质变的窗口注意力机制。别把它当成普通的Transformer来看它本质上是在解决一个老问题局部纹理的保真度和全局结构的连贯性怎么在同一个网络里共存。窗口注意力到底在干什么传统SRCNN、EDSR这些CNN模型感受野受限于卷积核大小。你堆100层理论上感受野能覆盖全图但实际训练时梯度传播到后面就衰减了远处像素对当前像素的影响微乎其微。这就是为什么CNN超分出来的图大块纹理区域经常出现“糊成一团”的现象。ViTVision Transformer把整张图当成一个序列做自注意力理论上能捕获全局依赖。但问题来了一张256x256的图patch size8序列长度就是1024自注意力的计算复杂度是O(n²)显存直接爆炸。更致命的是超分任务需要保留精细的像素级信息全局注意力会把远处的噪声也加权进来反而破坏了局部纹理。SwinIR的窗口注意力就是在这两个极端之间找了个平衡点。它把特征图切成不重叠的窗口每个窗口内做自注意力。窗口大小通常是7x7这样每个窗口内的patch数量只有49计算量可控。但问题又来了窗口之间没有信息交流局部建模强了全局一致性怎么办移位窗口那个让代码多写100行的设计SwinIR的解决方案是交替使用两种窗口划分方式。第一层用常规的均匀网格切分第二层就把窗口整体向右下角偏移几个像素。这个偏移量通常是窗口大小的一半比如窗口7x7偏移量就是3。这里有个容易踩坑的地方偏移后的窗口边界处的窗口大小可能不是7x7了。比如图像尺寸是64x64窗口7x7偏移3个像素后右下角的窗口可能只有4x4。如果你直接对这些小窗口做自注意力padding会引入边界伪影影响重建质量。正确的做法是用cyclic shift——把图像循环移位让那些不完整的窗口拼接到左上角形成一个完整的窗口。但注意循环移位后原本不相邻的像素被强行拼到了一起自注意力计算时不能让他们互相看到。这就需要mask来屏蔽这些“假邻居”。我当初写mask的时候犯过一个低级错误把mask的维度搞反了。SwinIR的mask是加在softmax之前的形状应该是[B, num_heads, N, N]其中N是窗口内的patch数。我写成了[B, N, N, num_heads]结果模型训练了三天loss死活不降。后来逐层打印attention map才发现mask根本没生效所有位置都在做全连接。局部-全局建模的实战细节SwinIR里真正让超分效果提升的不是窗口注意力本身而是它和残差连接的配合方式。每个Swin Transformer Block包含两个子模块窗口多头自注意力W-MSA和移位窗口多头自注意力SW-MSA中间夹着一个MLP。每个子模块后面都跟着LayerNorm和残差连接。这个结构看起来和标准Transformer差不多但有个关键差异SwinIR在W-MSA和SW-MSA之间没有用任何下采样或上采样操作。这意味着特征图的分辨率在整个stage里保持不变。这对于超分任务很重要——你不想在重建阶段丢失任何像素级信息。实际调试时我发现一个反直觉的现象把窗口大小从7改成8PSNR反而下降了0.1dB。原因在于窗口大小必须是奇数这样偏移半个窗口后边界处的窗口才能对称。7x7窗口偏移3个像素左右边界各损失3个像素剩下的窗口大小是1x7或7x1这些细长窗口内的像素相关性很弱自注意力学不到有用信息。而8x8窗口偏移4个像素边界窗口变成0x8直接报错。代码实现里的那些坑写SwinIR的forward函数时有个细节特别容易忽略输入特征图的尺寸必须是窗口大小的整数倍。如果不是需要做padding。但padding的方式有讲究——不能用零填充因为零填充会在边界处引入暗边超分后的图像边缘会发黑。我试过几种padding策略反射填充reflect效果最好但计算量大复制填充replicate次之速度更快。最终我选了复制填充因为超分任务对速度有要求而且复制填充在边界处的伪影比零填充小得多。另一个坑是相对位置偏置relative position bias。SwinIR在计算注意力时加了一个可学习的相对位置编码形状是[(2window_size-1), (2window_size-1)]。这个偏置矩阵的索引计算很容易搞错。比如窗口大小是7相对位置范围是[-6, 6]共13个值。你需要把二维相对位置映射到一维索引公式是index (pos_y window_size-1) * (2*window_size-1) (pos_x window_size-1)。这个公式少一个括号整个偏置就全乱了。个人经验性建议如果你现在要复现SwinIR做超分我建议你从Swin Transformer的官方实现开始改而不是从零写。官方代码里对cyclic shift和mask的处理已经验证过了你只需要把分类头改成上采样模块就行。上采样模块的选择也有讲究。SwinIR原文用的是pixel shuffle亚像素卷积但我在实际测试中发现对于4倍超分先用pixel shuffle上采样2倍再用双线性插值上采样2倍效果比直接用pixel shuffle上采样4倍要好。原因是pixel shuffle在4倍时每个像素需要从16个通道中重组通道间的相关性很难学容易产生棋盘伪影。训练策略上别一上来就用大batch size。SwinIR的窗口注意力对batch size敏感batch size太大时不同样本的窗口划分不一致梯度更新方向会互相抵消。我一般先用batch size8训练100个epoch再增加到16微调。最后说一个玄学经验SwinIR的深度和宽度不是越深越好。我试过把RSTBResidual Swin Transformer Block从6个增加到12个PSNR反而下降了0.05dB。原因是深层网络的梯度回传路径太长窗口注意力的局部性导致梯度在传播过程中丢失了全局信息。对于超分任务6个RSTB已经足够再深就是过拟合了。调试SwinIR的过程本质上是在局部细节和全局结构之间找平衡。窗口大小、偏移量、深度、宽度每个参数都在影响这个平衡。没有通用的最优配置只有针对你数据集的最优配置。多跑几组消融实验比看论文里的ablation study更有用。