【从零实现注意力机制】手把手教你写一个带QKV的Attention模块

📅 2026/7/2 5:31:42
【从零实现注意力机制】手把手教你写一个带QKV的Attention模块
不抄Transformer我们自己造轮子彻底搞懂注意力机制的计算流程一、为什么还要自己实现注意力很多人学完Transformer后对Attention的理解还停留在论文公式层面。但面试一问QKV具体怎么计算的就支支吾吾说不清楚。关键问题注意力机制到底在做什么输入一个查询Q它想知道自己跟谁最相关K是标签V是内容通过Q和K的匹配度从V中提取信息最终Q被升级成带上下文信息的Q举个例子你查苹果这个词Q词典里有32个词条K每个词条有64维特征V。Attention算出你查的苹果跟这32个词条的相似度然后按照相似度加权组合这32个词条的特征得到一个增强版的苹果表示。这就是Attention的本质用查询去检索信息加权聚合出更有价值的表示。二、我们的任务已知V32个单词每个单词64维特征 →[1, 32, 64]K32个单词的索引标记 →[1, 1, 32]Q查询张量比如查我 →[1, 1, 32]要做什么计算Q与32个单词的相关性注意力权重分布用权重加权V得到一个增强版Q输出增强后的Q[1, 1, 32]和注意力权重[1, 32]三、类结构设计MyAttn ├── __init__(self, query_size, key_size, value_size1, value_size2, output_size) │ ├── self.attn: Linear(3232 → 32) # 计算注意力得分 │ └── self.attn_combine: Linear(3264 → 32) # 融合Q和V │ └── forward(self, Q, K, V) ├── 1. 计算注意力权重 │ ├── 拼接Q和K → [1, 64] │ ├── 线性层 → [1, 32] (得分) │ └── Softmax → [1, 32] (概率分布) ├── 2. 加权聚合V │ └── 权重 × V → [1, 1, 64] ├── 3. 融合Q和聚合后的V │ ├── 拼接 → [1, 1, 96] │ └── 线性层 → [1, 1, 32] └── 4. 返回输出和注意力权重四、完整代码逐行解析import torch import torch.nn as nn import torch.nn.functional as F class MyAttn(nn.Module): def __init__(self, query_size, key_size, value_size1, value_size2, output_size): super().__init__() # 记录维度参数方便调试和扩展 self.query_size query_size self.key_size key_size self.value_size1 value_size1 # 序列长度单词数 self.value_size2 value_size2 # 词向量维度 self.output_size output_size # 【核心1】注意力得分计算层 # 输入Q拼接K (323264维) # 输出32维每个单词一个得分 self.attn nn.Linear(query_size key_size, value_size1) # 【核心2】注意力融合层 # 输入Q拼接加权后的V (326496维) # 输出32维最终表示 self.attn_combine nn.Linear(query_size value_size2, output_size)为什么设计两个线性层attn负责相关性打分输入Q和K输出跟V序列长度一致的得分向量attn_combine负责融合升级把原始Q和从V中提取的信息合并得到增强版Qdef forward(self, Q, K, V): # 阶段1计算注意力权重 # 为什么要用Q[0]因为假设batch_size1取出具体数据方便计算 # Q[0]: [1, 32], K[0]: [1, 32] → cat → [1, 64] qk_cat torch.cat((Q[0], K[0]), dim-1) # 线性变换得到得分每个单词一个分数 # [1, 64] → [1, 32] attn_scores self.attn(qk_cat) # Softmax归一化为概率分布 # [1, 32] → [1, 32] (所有概率和为1) attn_weights F.softmax(attn_scores, dim-1)Softmax的作用将原始得分可正可负转换为概率分布0-1之间和为1让相关的单词权重高不相关的权重低dim-1表示在最后一个维度32个单词上做归一化# 阶段2用权重加权聚合V # 扩展维度准备做矩阵乘法 # [1, 32] → [1, 1, 32] attn_weights_expanded attn_weights.unsqueeze(0) # 批量矩阵乘法权重 × V # [1, 1, 32] [1, 32, 64] [1, 1, 64] attn_applied torch.bmm(attn_weights_expanded, V)bmm的作用bmm batch matrix multiplication批量矩阵乘法权重矩阵[1, 1, 32]× V矩阵[1, 32, 64]本质用32个权重系数对32个64维向量做加权求和结果一个64维的加权平均向量# 阶段3融合Q和提取的信息 # 拼接原始Q 从V中提取的信息 # [1, 1, 32] [1, 1, 64] [1, 1, 96] output_cat torch.cat((Q, attn_applied), dim-1) # 降维融合后映射到目标维度 # [1, 1, 96] → [1, 1, 32] output self.attn_combine(output_cat) return output, attn_weights为什么要融合Q和attn_appliedattn_applied是从V中提取的外部信息原始Q包含自身信息融合两者得到自我上下文的增强表示类似于ResNet的残差思想输出 Q 从V中提取的信息五、测试代码if __name__ __main__: # 维度配置 query_size, key_size, value_size1, value_size2, output_size 32, 32, 32, 64, 32 # 构造输入 Q torch.randn(1, 1, query_size) # [1, 1, 32] K torch.randn(1, 1, key_size) # [1, 1, 32] V torch.randn(1, value_size1, value_size2) # [1, 32, 64] # 前向传播 my_attn MyAttn(query_size, key_size, value_size1, value_size2, output_size) output, attn_weights my_attn(Q, K, V) print(f输出形状: {output.shape}) # [1, 1, 32] print(f注意力权重形状: {attn_weights.shape}) # [1, 32]六、维度变化速查表步骤操作输入形状输出形状说明1拼接Q,KQ:[1,32], K:[1,32][1,64]准备打分2线性层attn[1,64][1,32]每个单词一个得分3Softmax[1,32][1,32]归一化为概率4unsqueeze[1,32][1,1,32]匹配batch维度5bmm[1,1,32] × [1,32,64][1,1,64]加权聚合V6拼接Q[1,1,32] [1,1,64][1,1,96]融合信息7线性层combine[1,1,96][1,1,32]降维输出七、核心问题QAQ1为什么用Q[0]而不是直接用QA因为假设batch_size1Q[0]取出第一维数据从[1,1,32]变成[1,32]去掉冗余的batch维度让计算更直观。但实际项目中建议直接用cat((Q,K), dim-1)保持通用性。Q2attn和attn_combine的区别Aattn计算注意力权重输入Q和K输出32维权重向量attn_combine融合Q 加权后的V输出最终的增强表示Q3为什么要用bmmAbmm专门用于批量矩阵乘法高效且自动处理batch维度。这里用权重向量[1,1,32]去加权V的32个词向量[1,32,64]得到加权和[1,1,64]。Q4注意力权重为什么要用SoftmaxASoftmax将得分映射为概率分布确保所有权重都在0-1之间且和为1。这样相关的单词获得高权重不相关的获得低权重符合注意力机制聚焦重要信息的思想。八、总结注意力机制的三步走打分用Q和K计算相关性得分加权Softmax归一化为权重加权聚合V融合将聚合结果与原始Q融合得到增强表示代码设计的核心思想用线性层attn学习如何打分用线性层attn_combine学习如何融合所有参数通过反向传播自动学习什么时候用Attention需要从大量信息中提取关键内容时需要建模长距离依赖时需要动态分配计算资源时