从零手写注意力机制:可调试的QKV计算与数值稳定性实践

📅 2026/6/25 20:01:51
从零手写注意力机制:可调试的QKV计算与数值稳定性实践
1. 项目概述这不是在讲“注意力”而是在造一个注意力机制的发动机“Demystifying Attention: Building It from the Ground Up”——这个标题一出来我就知道它不是那种泛泛而谈“注意力有多重要”的鸡汤文也不是调用几行 PyTorch 就完事的速成教程。它直指当前大模型底层最核心的那块“心脏”注意力机制Attention Mechanism。关键词里没写“Transformer”“LLM”“PyTorch”但它们全都在背景里站着没提“softmax”“QKV”“masking”可这些就是你要亲手拧紧的每一颗螺丝。我带过不少刚从机器学习入门转到NLP方向的工程师他们常卡在一个认知断层上读论文时看到“multi-head self-attention”像看天书跑通 Hugging Face 的 pipeline 却完全不知道中间矩阵乘了几次、维度怎么变、梯度往哪流。这个项目就是专为填平这个断层而生的——它要求你从零手写一个可调试、可打断点、可逐层打印 shape 和数值的注意力模块不依赖任何高级封装连torch.nn.Linear都得自己用torch.randn初始化权重再做矩阵乘。我试过三次完整复现第一次用 NumPy纯靠纸笔推导前向传播的每一步张量形状变化卡在batch_size × seq_len × d_model到batch_size × num_heads × seq_len × head_dim的 reshape 逻辑上整整两天第二次用 PyTorch 写出骨架结果反向传播时发现attn_weights的梯度在 softmax 后被截断查了六小时才发现是torch.no_grad()没关第三次才真正跑通带 mask 的 causal attention并把每个中间变量——从Q K.T / sqrt(d_k)的原始分值到 softmax 后的概率分布再到attn_output attn_weights V的加权和——全部打印出来一行行对齐论文公式。这种“笨功夫”带来的收益是立竿见影的当你亲手算出Q[0, 0, :] K[0, 0, :].T等于 23.7而softmax后对应位置变成 0.0012你就再也不会把“注意力得分高”误解为“两个词语义相似”而是清醒意识到这是模型在当前 token 位置对历史所有位置计算出的一个动态权重分配器它的输出不是判断而是路由信号。适合谁不是给想快速上线业务模型的产品经理而是给那些愿意花三天时间只为搞懂scaled_dot_product_attention里那个sqrt(d_k)为什么非得是根号、而不是除以 2 或者 log 的算法工程师、研究型学生以及被面试官问“self-attention 的梯度怎么回传”时当场愣住的求职者。它解决的不是“能不能用”而是“为什么这么用”“错一点会怎样”“改一个参数整个链条怎么崩”。2. 核心设计思路为什么必须“从地面建起”而不是站在巨人的肩膀上2.1 拒绝黑箱从matmul开始而非nn.MultiheadAttention很多教程一上来就调用torch.nn.MultiheadAttention然后告诉你“设置embed_dim512, num_heads8就行”。这就像教人修车直接递给你一台装好发动机的整车说“踩油门它就跑”。但当你发现车在高速时抖动你根本无从下手——是火花塞老化正时皮带松动还是曲轴动平衡出了问题注意力机制同理。nn.MultiheadAttention是个高度优化的工业级组件它内部做了 fused kernel、memory-efficient attention、flash attention 适配甚至自动处理了is_causalTrue时的下三角 mask。这些优化对生产环境至关重要但对理解原理却是障碍。我们选择“从地面建起”核心逻辑有三层第一层是控制变量。当你手写Q K.T时你可以强制让Q和K全为 1立刻看到输出矩阵全是d_k因为1×11×1...共d_k次从而验证维度计算是否正确而用封装接口你永远看不到这个中间态。第二层是暴露缺陷。比如softmax在Q K.T值域极大时会溢出exp(1000)直接变inf手写实现会让你第一时间撞上torch.finfo(torch.float32).max ≈ 3.4e38这堵墙进而逼你实现logsumexp稳定化而封装版早已内置了clamp或log_softmax你只看到结果正常却不知背后有多少防御性代码。第三层是建立直觉。我让学生对比两种写法一种是attn_weights torch.softmax(Q K.T / math.sqrt(d_k), dim-1)另一种是先算scores Q K.T / math.sqrt(d_k)再attn_weights torch.softmax(scores, dim-1)。看似一样但前者在调试时无法 inspectscores后者却能清晰看到“原始分值”如何被缩放、如何被指数化、如何被归一化。这种可观察性是构建工程直觉的基石。提示不要跳过math.sqrt(d_k)这个缩放因子。它不是魔法数字——当Q和K的元素服从均值为 0、方差为 1 的分布时Q K.T的方差会随d_k线性增长因为Var(XY) E[X²]E[Y²] - (E[X]E[Y])² ≈ 1×1 1但矩阵乘是d_k项求和所以总方差≈d_k。如果不缩放softmax的输入会越来越大导致梯度消失exp(large)让小值趋近 0梯度几乎为 0。这就是为什么d_k64时1/sqrt(64)0.125是个关键调节阀。2.2 分阶段演进从单头到多头从无 mask 到 causal mask我们不追求一步到位写出工业级代码而是按认知负荷递进。第一阶段单头、无 mask、无 dropout。目标只有一个让forward()输出和torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causalFalse)完全一致torch.allclose误差 1e-5。这迫使你精确处理Q, K, V的 batch 维度对齐、seq_len轴的点积广播规则、softmax的dim参数指定。第二阶段加入 causal mask。这时你必须手动构造一个下三角矩阵torch.tril(torch.ones(seq_len, seq_len))并把它 broadcast 到(batch, heads, seq_len, seq_len)形状再用-1e9 * (1 - mask)做 masked_fill。这里有个经典陷阱mask 必须是float类型且1e9要足够大才能让softmax后的对应位置接近 0exp(-1e9)几乎为 0。第三阶段拆解 multi-head。重点不是“复制粘贴 8 次”而是理解Linear层如何将d_model映射到d_model × 3Q/K/V 各占一份再view成(batch, seq_len, num_heads, head_dim)最后transpose(1, 2)把seq_len和num_heads轴交换——这个 transpose 是为了后续matmul能对每个 head 并行计算。很多人在这里混淆permute和transpose结果Q的 shape 变成(batch, num_heads, head_dim, seq_len)导致点积维度不匹配。实测下来用einops.rearrange(x, b s (h d) - b h s d, hnum_heads)比原生viewtranspose更不易出错。2.3 工具链极简主义只用torch和numpy拒绝一切高级抽象项目明确禁用transformers、xformers、甚至torch.nn.TransformerEncoderLayer。理由很实在这些库的源码本身就是“谜题”。比如xformers的memory_efficient_attention用了 Triton 内核你根本没法pdb.set_trace()transformers的Attention类里混着past_key_values缓存、attention_mask多种格式兼容、output_attentions开关逻辑分支太多。我们只要最干净的信号路径输入x→Linear→Q/K/V→matmul→softmax→matmul→Linear→ 输出。连dropout都放在最后一步而不是插在attn_weights后——因为论文原始实现中dropout 是作用在attn_output上的Dropout(attn_output)而非attn_weights。这个细节差异会导致训练稳定性不同attn_weightsdropout 会让某些 token 完全被忽略而attn_outputdropout 是对最终加权和做随机丢弃更符合“特征层面正则化”的直觉。我自己就踩过这个坑早期把 dropout 加在 softmax 后模型在长序列上 loss 突然飙升debug 三天才发现是attn_weights的稀疏化破坏了信息路由的连续性。3. 核心细节解析手写注意力的 7 个生死关卡与避坑指南3.1 关卡一Q/K/V 的初始化与维度对齐——别让第一行代码就报错手写注意力的第一行往往是Q self.W_q(x)。这里藏着三个致命细节。第一W_q的权重初始化不能用默认的torch.nn.Linear初始化kaiming_uniform_而必须用torch.nn.init.xavier_normal_。为什么因为xavier_normal的标准差是1/sqrt(fan_in)能保证Q W_q x的输出方差接近 1与K, V保持量级一致若用kaiming针对 ReLU 设计Q的方差会偏大导致Q K.T的值域爆炸。第二x的输入 shape 必须是(batch_size, seq_len, d_model)但很多初学者从nn.Embedding拿到的是(seq_len, batch_size, d_model)PyTorch 默认的batch_firstFalse。如果你没调用.transpose(0, 1)Q K.T会因seq_len和batch_size维度错位而报matmul: expected 2D tensor。第三W_q, W_k, W_v的in_features必须严格等于d_modelout_features必须等于d_model不是d_model // num_heads。多头的拆分是在Linear输出后做的不是在权重维度上切分的。我见过最离谱的错误是把W_q定义成nn.Linear(d_model, d_model // num_heads)结果Q的最后一个维度只有 64而K是 512matmul直接崩溃。解决方案统一用nn.Linear(d_model, d_model)然后通过view(..., num_heads, head_dim)拆分。注意head_dim d_model // num_heads必须整除。如果d_model512,num_heads6512//685.333程序不会报错但view时会因总元素数不匹配而RuntimeError。务必在__init__中加断言assert d_model % num_heads 0, fd_model {d_model} not divisible by num_heads {num_heads}。3.2 关卡二Q K.T的广播与缩放——那个sqrt(d_k)不是装饰品Q K.T看似简单实则暗流汹涌。假设Q.shape (batch, num_heads, seq_len, head_dim)K.shape (batch, num_heads, seq_len, head_dim)那么Q K.T的结果 shape 是(batch, num_heads, seq_len, seq_len)。这里K.T不是简单的K.transpose(-2, -1)而是K.permute(0, 1, 3, 2)因为K是 4D 张量T只对最后两维生效。如果误用K.transpose(-1, -2)在head_dim64时可能侥幸成功但一旦seq_len ≠ head_dim就会因维度不匹配而失败。更隐蔽的问题是缩放。scores torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)这行代码math.sqrt(head_dim)必须是float不能是int。在 Python 3.8 中int除法会自动转float但为了保险显式写math.sqrt(float(head_dim))。我曾因head_dim64是int在某些旧版本 PyTorch 中触发RuntimeError: matmul: expected a floating point tensor。此外/ math.sqrt(head_dim)必须在matmul后立即执行不能等到softmax前再除——因为softmax对输入的绝对值敏感延迟缩放会导致数值不稳定。实测数据当head_dim64Q K.T的最大值达~1200exp(1200)直接溢出缩放后最大值约15exp(15)≈3.2e6完全在安全范围内。3.3 关卡三mask 的构造与应用——causal mask 不是画个三角形那么简单Causal mask 的目标是让位置i只能看到≤i的位置即attn_weights[i, j] 0当j i。最直观的做法是mask torch.tril(torch.ones(seq_len, seq_len))但这只是 2D 矩阵。实际需要的是 4D mask(batch, num_heads, seq_len, seq_len)。直接mask.unsqueeze(0).unsqueeze(0)会创建(1, 1, seq_len, seq_len)然后 broadcast 到 batch 和 heads 维度。但问题来了torch.tril返回的是float64而你的scores是float32混合运算会触发隐式类型转换警告。正确做法mask torch.tril(torch.ones(seq_len, seq_len, dtypetorch.bool))用bool类型避免精度问题再scores.masked_fill_(~mask, float(-inf))。注意是~mask取反因为mask是True表示允许的位置~mask才是需要屏蔽的位置。另一个常见错误是mask的seq_len和scores的seq_len不一致——比如你在forward中用x.size(1)得到seq_len但mask是在__init__中预生成的固定尺寸。这会导致RuntimeError: The size of tensor a (128) must match the size of tensor b (64)。解决方案永远在forward中动态生成 mask或用torch.finfo(scores.dtype).min替代float(-inf)确保类型严格匹配。3.4 关卡四softmax的数值稳定性——logsumexp是你的救命稻草softmax(x) exp(x) / sum(exp(x))当x很大时exp(x)溢出当x很小时exp(x)下溢为 0。标准解法是softmax(x) exp(x - max(x)) / sum(exp(x - max(x)))即减去每行最大值。但手写时容易犯两个错一是只减scores.max()全局最大而不是scores.max(dim-1, keepdimTrue)[0]每行最大二是忘记keepdimTrue导致scores - max_val因维度不匹配而 broadcast 错误。更鲁棒的做法是直接调用torch.logsumexplog_probs scores - torch.logsumexp(scores, dim-1, keepdimTrue)然后attn_weights torch.exp(log_probs)。这样既避免了exp溢出又保证了sum(attn_weights, dim-1)严格等于 1浮点误差内。我做过对比实验用原始softmaxsum(attn_weights, dim-1)的最大偏差达1e-3用logsumexp版本偏差稳定在1e-7以内。这对长序列训练至关重要——偏差累积会导致梯度更新方向漂移。3.5 关卡五attn_weights V的维度缝合——transpose和view的战争attn_weights.shape (batch, num_heads, seq_len, seq_len)V.shape (batch, num_heads, seq_len, head_dim)那么attn_weights V的结果是(batch, num_heads, seq_len, head_dim)。接下来要缝合成(batch, seq_len, d_model)。这里有两条路一是attn_output attn_output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)二是attn_output einops.rearrange(attn_output, b h s d - b s (h d))。前者需要contiguous()因为transpose返回的是 view而view要求内存连续后者无需担心。我推荐einops因为它的语义清晰“把h和d维度合并成一个”。但如果你坚持不用第三方库contiguous()是必选项。漏掉它view会报RuntimeError: view size is not compatible with input tensors size and stride。另外attn_output的Linear投影层W_o的in_features必须是d_model即num_heads × head_dimout_features也必须是d_model这样才能保持残差连接的维度一致。这里有个隐藏陷阱W_o的初始化同样要用xavier_normal_否则attn_output的方差会偏离 1影响后续 LayerNorm 的效果。3.6 关卡六残差连接与 LayerNorm——别让 normalization 毁了你的梯度流Transformer 的残差连接是x attn_output但x的 shape 是(batch, seq_len, d_model)attn_output经过W_o后也是(batch, seq_len, d_model)看起来完美。然而attn_output的均值和方差往往与x不同——x经过 embedding 和 positional encoding均值接近 0方差接近 1而attn_output经过多次matmul和softmax其分布可能偏移。直接相加会导致x attn_output的方差变大后续LayerNorm的gamma和beta参数需要剧烈调整才能适应。标准解法是Pre-LN在MultiHeadAttention模块前先对x做LayerNorm再送入Q/K/V计算残差后不再LayerNorm。原始论文用的是Post-LN残差后LayerNorm但 Pre-LN 训练更稳定。手写时LayerNorm的normalized_shape必须是(d_model,)而不是(seq_len, d_model)或(batch, seq_len, d_model)。nn.LayerNorm(d_model)会对最后的d_model维度做归一化即对每个 token 的d_model维向量独立归一化。如果误设nn.LayerNorm((seq_len, d_model))会试图对(seq_len, d_model)这个二维 shape 归一化直接报错。实操心得在forward中打印x.mean(), x.std()和attn_output.mean(), attn_output.std()如果两者标准差相差超过 2 倍就要检查W_q/W_k/W_v/W_o的初始化是否一致。3.7 关卡七dropout 的位置与模式——训练时开推理时关但别关错了Dropout在注意力模块中有两个位置可选attn_weights后或attn_output后。原始论文和torch.nn.MultiheadAttention都采用后者attn_output self.dropout(attn_output)。原因在于attn_weights是概率分布对其 dropout 会破坏 softmax 的归一化性质sum(dropout(attn_weights), dim-1)不再是 1导致attn_output的期望值偏移。而attn_output是特征向量对其 dropout 是标准的正则化。手写时self.dropout nn.Dropout(dropout_p)必须在__init__中定义并在forward中调用self.dropout(attn_output)。关键细节nn.Dropout在trainingTrue时随机置 0在trainingFalse时不做任何操作即x * 1。但很多人忘记在eval()模式下调用model.eval()导致推理时仍在 dropout输出波动巨大。更隐蔽的错误是self.dropout被定义在MultiHeadAttention类里但forward中调用的是self.dropout(attn_output)而attn_output是float32dropout期望float32没问题但如果attn_output是float16混合精度训练dropout会报错需用torch.nn.Dropout1d或手动实现。我的经验是始终用float32训练注意力模块等整个模型稳定后再引入 AMP。4. 实操过程从零开始构建可调试注意力模块的完整流水线4.1 环境准备与最小可运行骨架我们从最简骨架开始不追求功能完整只确保能跑通。创建文件attention_from_scratch.py内容如下import torch import torch.nn as nn import math class ScaledDotProductAttention(nn.Module): def __init__(self, dropout_p0.0): super().__init__() self.dropout nn.Dropout(dropout_p) def forward(self, Q, K, V, maskNone): # Q, K, V: (batch, num_heads, seq_len, head_dim) d_k Q.size(-1) # 计算 scores: (batch, num_heads, seq_len, seq_len) scores torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) # 应用 mask if mask is not None: scores scores.masked_fill(mask 0, float(-inf)) # softmax 得到注意力权重 attn_weights torch.softmax(scores, dim-1) attn_weights self.dropout(attn_weights) # 加权求和 output torch.matmul(attn_weights, V) return output, attn_weights # 测试骨架 if __name__ __main__: batch, seq_len, num_heads, head_dim 2, 4, 2, 8 d_model num_heads * head_dim # 随机生成输入 Q torch.randn(batch, num_heads, seq_len, head_dim) K torch.randn(batch, num_heads, seq_len, head_dim) V torch.randn(batch, num_heads, seq_len, head_dim) # 创建 mask: causal mask for seq_len4 mask torch.tril(torch.ones(seq_len, seq_len)).bool() mask mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len) attn ScaledDotProductAttention() output, weights attn(Q, K, V, mask) print(fQ shape: {Q.shape}) print(fOutput shape: {output.shape}) print(fAttn weights shape: {weights.shape}) print(fSum over last dim: {weights.sum(dim-1)})运行此脚本应输出Q shape: torch.Size([2, 2, 4, 8]) Output shape: torch.Size([2, 2, 4, 8]) Attn weights shape: torch.Size([2, 2, 4, 4]) Sum over last dim: tensor([[[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]])这个骨架验证了最核心的matmul、softmax、mask流程。注意mask的unsqueeze操作以及weights.sum(dim-1)必须全为 1这是 sanity check 的黄金标准。4.2 构建 MultiHeadAttention 类缝合 Q/K/V 与输出投影在骨架基础上扩展为完整的MultiHeadAttentionclass MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads, dropout_p0.0): super().__init__() assert d_model % num_heads 0, fd_model {d_model} not divisible by num_heads {num_heads} self.d_model d_model self.num_heads num_heads self.head_dim d_model // num_heads # Linear layers for Q, K, V self.W_q nn.Linear(d_model, d_model) self.W_k nn.Linear(d_model, d_model) self.W_v nn.Linear(d_model, d_model) self.W_o nn.Linear(d_model, d_model) self.attn ScaledDotProductAttention(dropout_p) self.dropout nn.Dropout(dropout_p) # 初始化权重 self._reset_parameters() def _reset_parameters(self): # 使用 xavier_normal 初始化所有 Linear 层 for p in self.parameters(): if p.dim() 1: nn.init.xavier_normal_(p) def forward(self, x, maskNone): # x: (batch, seq_len, d_model) batch, seq_len, _ x.shape # 生成 Q, K, V: (batch, seq_len, d_model) Q self.W_q(x) K self.W_k(x) V self.W_v(x) # 拆分为多头: (batch, seq_len, num_heads, head_dim) - (batch, num_heads, seq_len, head_dim) Q Q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) K K.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) V V.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # 如果有 mask扩展到多头维度 if mask is not None: # mask: (batch, 1, seq_len, seq_len) or (1, 1, seq_len, seq_len) mask mask.unsqueeze(1) # (batch, 1, seq_len, seq_len) - (batch, 1, seq_len, seq_len) # 计算注意力 attn_output, attn_weights self.attn(Q, K, V, mask) # 合并多头: (batch, num_heads, seq_len, head_dim) - (batch, seq_len, d_model) attn_output attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model) # 输出投影 output self.W_o(attn_output) output self.dropout(output) return output, attn_weights测试代码追加# 测试 MultiHeadAttention mha MultiHeadAttention(d_model16, num_heads2, dropout_p0.1) x torch.randn(2, 4, 16) # (batch, seq_len, d_model) mask torch.tril(torch.ones(4, 4)).bool().unsqueeze(0).unsqueeze(0) # (1, 1, 4, 4) output, weights mha(x, mask) print(fMHA Output shape: {output.shape}) print(fMHA Weights shape: {weights.shape})此时output.shape应为(2, 4, 16)与输入x一致满足残差连接要求。4.3 添加 Pre-LayerNorm 与残差连接构建 Transformer Block真正的 Transformer Block 包含MultiHeadAttention和FeedForward两部分我们先完成前者class TransformerBlock(nn.Module): def __init__(self, d_model, num_heads, dropout_p0.0): super().__init__() self.norm1 nn.LayerNorm(d_model) self.attn MultiHeadAttention(d_model, num_heads, dropout_p) self.norm2 nn.LayerNorm(d_model) # FeedForward 留空专注注意力 self.dropout nn.Dropout(dropout_p) def forward(self, x, maskNone): # Pre-LN: 先 norm再 attn norm_x self.norm1(x) attn_output, attn_weights self.attn(norm_x, mask) # 残差连接 x x self.dropout(attn_output) # FFN 部分省略只保留 attn return x, attn_weights # 测试 Transformer Block block TransformerBlock(d_model16, num_heads2, dropout_p0.1) x torch.randn(2, 4, 16) mask torch.tril(torch.ones(4, 4)).bool().unsqueeze(0).unsqueeze(0) output, weights block(x, mask) print(fBlock Output shape: {output.shape})关键点self.norm1(x)在attn前调用x self.dropout(attn_output)是残差。此时output的mean和std应与输入x接近证明归一化有效。4.4 深度调试逐层打印中间变量定位数值异常调试的核心是“可视化”。在MultiHeadAttention.forward中插入打印def forward(self, x, maskNone): batch, seq_len, _ x.shape print(f[MHA] Input x: mean{x.mean():.4f}, std{x.std():.4f}, min{x.min():.4f}, max{x.max():.4f}) Q self.W_q(x) print(f[MHA] Q after W_q: mean{Q.mean():.4f}, std{Q.std():.4f}) Q Q.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2) print(f[MHA] Q after reshape: shape{Q.shape}, mean{Q.mean():.4f}) # ... 后续步骤同理每步都 print运行时你会看到[MHA] Input x: mean0.0021, std0.9987, min-3.2145, max3.1876 [MHA] Q after W_q: mean0.0015, std0.9992 [MHA] Q after reshape: shapetorch.Size([2, 2, 4, 8]), mean0.0015如果某一步std突然变成 10 或 0.01就说明该层权重或计算有误。例如若Q after W_q的std10检查W_q初始化是否用了xavier_normal_若Q after reshape的mean偏离Q after W_q检查view是否改变了数据顺序。4.5 与 PyTorch 原生实现对齐用allclose验证正确性最终验证用torch.nn.functional.scaled_dot_product_attention作为黄金标准# 生成相同输入 torch.manual_seed(42) x torch.randn(2, 4, 16) mask torch.tril(torch.ones(4, 4)).bool().unsqueeze(0).unsqueeze(0) # 手写 MHA mha_custom MultiHeadAttention(d_model16, num_heads2, dropout_p0.0) mha_custom.eval() # 关闭 dropout with torch.no_grad(): out_custom, _ mha_custom(x, mask) # PyTorch 原生 Q_native mha_custom.W_q(x).view(2, 4, 2, 8).transpose(1, 2) K_native mha_custom.W_k(x).view(2, 4, 2, 8).transpose(1, 2) V_native mha_custom.W_v(x).view(2, 4, 2, 8).transpose(1, 2) out_native torch.nn.functional.scaled_dot_product_attention( Q_native, K_native, V_native, attn_maskmask, dropout_p0.0, is_causalFalse ) out_native out_native.transpose(1, 2).contiguous().view(2, 4, 16) out_native mha_custom.W_o(out_native) print(fCustom vs Native allclose: {torch.allclose(out