混合注意力学习(1): 线性注意力 📅 2026/6/26 9:28:52 Prefill、Decode与KVCache在开始本文之前首先应该介绍一下什么是prefill什么是decode以及对应的KVCache。这样可以更好理解内存复杂度。小学生们都日益了解了现有的LLM大语言模型的组成成分主要是Transformer Block。其中注意力机制具有如下的计算公式在推理过程中我们采用 decoder-only 架构因此注意力机制为“掩码自注意力”。流程如下从上面的流程图中我们看到Key和Value是复用的即每一次新的计算都需要用到先前的输入生成的Key和Value。因此我们将其称为KVCache。我们可以发现每一个Transformer Block的KVCache都是随着序列长度线性增长的。因此整体的空间复杂度为 (假定)。因此我们的prefill和decode流程简化来说是如下所示的prefill要求多个seq并行进入进行计算decode则每次接收一个上次生成的token进行计算。重要的两个指标为TTFTPrefill开始到第一个token生成所需要的时间TPOT每个token生成之间的时间[1]。[1:1][1:2]同时我们也可以用roofline模型来刻画我们的序列长度请求次数导致的计算上限和存储上限。对于短序列prefill和decode阶段主要是存储受限计算强度低但是需要大量读写KVCache。对于长序列prefill阶段主要是计算受限计算强度高GEMM这样我们就介绍完了KVCachePrefill和Decode了解了他们的不同和受限情况。混合注意力架构混合注意力架构包括了稀疏注意力与线性注意力而线性注意力机制又起源于这篇Transformers are RNNs的文章。我们都大致介绍一下作为我们的基本背景。线性注意力这里也可以阅读苏神的文章[2][3]。我们总结了一张图如下Transformers 就是 RNNs[4]在上文中我们已经提到了经典的自注意力机制的计算我们这里再展示一次对于上述式子我们很容易想当然地采用近似函数来代替。我们首先分析我们的矩阵运算结果如下接下来我们就可以尝试用近似函数代替也就是 :这里距离我们的线性注意力还有一段距离但是已经不远了。我们回想在古老的分离向量机中为了实现非线性支持向量机我们需要使用到核函数技巧。为了将非线性问题转换成线性问题我们采用核函数技巧。对于输入空间 为特征空间希尔伯特空间如果存在映射函数 满足 , 则 为核函数 为映射函数。[5]在注意力机制中我们只需要要求函数非负来表示其概率性。对此可以参见[6]因此我们采用核函数技巧进行分解就可以得到如下的公式很容易注意到我们可以采用结合律来提取出公因式 。注意到 采用结合律需要转置。因此我们有这样我们很容易可以看出原先softmax的计算复杂度是 随着序列以 的平方复杂度增长。而核函数在维持原有隐维度的条件下保持 的线性复杂度增长并且隐维度 仍然具有可优化的空间。这就是线性注意力的由来。这样如果我们考虑到推理模式下的decode-only掩码自注意力场景下我们不会计算全部的序列 而是计算到当前序列 这样我们就有我们令 就可以得到很明显就是我们对应的最简单的RNN架构。这也说明了(1) Transformer其实是大号RNN(2)线性注意力在理论层面是可行的。大量的实验发现分母会导致严重的数值不稳定问题并且可以无需映射函数直接采用分子参与计算。[7]。这样实际上为 .我们最后再看看梯度的计算。注意在训练场景下我们是全序列因此我们在给定分子 和损失函数 的条件下参考[8]处的运算我们有这样我们就把所有的梯度都计算出来了。小贴士在参考原有的链式法则基础上适当通过各种转置方法保证梯度张量和参数张量保持一致因为梯度张量应该和参数张量相同这样才能利用梯度下降更新可以提升我们的计算速度。Fast Weight Programmers与DeltaNet[9]线性注意力允许我们使用如下的方式更新当前状态模仿RNN这样对于一个长为 的序列我们每次计算步数为 计算复杂度为 训练过程中需要的空间复杂度为 推理过程中需要的时间复杂度为 .并且在这篇论文[9:1]中 实际上是一个关联性的内存存储了当前瞬态从key到Value的映射。这样的更新可以看作是一个无上界的关联损失函数的梯度下降从而持续强化最近的键值对没有任何遗忘。也就是文章[7:1]中说的这样的持续无遗忘将会在长上下文中造成严重的干扰。我们的人脑也会通过忘记无关紧要的久远的非必要记忆来保证我们对当前上下文的专注。我的某位朋友指出我需要提供为什么梯度更新和上面的线性注意力更新是等价的。在此给出确切的证明。对于此[9:2] [10]提出了如下的更新方式也称为Delta Rule。新的 到来。Read取出上一次的 构造未更新前的前Key-我们看到Key和Value关联模式通过一个学习率网络 来构造动态学习率 是激励函数通过学习率控制K-V关联性。更新状态矩阵实现Write遗忘。这样也等价于重建了一个无上界的Loss重建成如下形式通过学习系数 来单步梯度下降修正自己当前时间步下的记忆关联 。这样的变换允许了硬件通过分块并行来提升计算速度。这篇文章中详细说明了针对线性注意力的高效并行[7:2]。证明如下然后通过 来更新我们的权重。更新的公式则为这就是DeltaNet提出的重建损失函数——来实现一定的遗忘性。 也被称为 Delta 系数。Gated DeltaNet[11]为了进一步减少历史记忆和状态对现有的影响Mamba2进一步通过权重衰减来实现对过去的遗忘更进一步结合Delta规则实现遗忘这样就等价于通过如下的损失函数进行梯度下降。因此 成为门控权重衰减系数(gating weight decay)这就是Gated Delta Network的具体形式。线性注意力并行机制[7:3]传统FWP形式的线性注意力并行线性注意力在时间迭代的形式如下我们对比一下线性注意力的并行形式与迭代形式。针对并行形式我们将 堆叠在一起形成一个整体这样就形成了 。注意堆叠的方式是 均采用如同 的堆叠方式。这样我们就拥有了如下的计算公式而最终计算结果需要要求查询不能看到未来的键和值不然就成透视未来了因此我们加入一个下三角掩码即可。这样最终输出就应该是我们比较一下并行算法和迭代算法的不同。在复杂度计算中我们假设 这样更直观一些。我们每一层Transformer Block的复杂度如下算法时间复杂度空间复杂度计算步数时间步迭代推理训练并行计算诶为什么并行计算方式的时间复杂度更高但是执行时间更低呢不要忘记并行计算的优势在于同时计算速度快瓶颈因素转移到了计算步数上。对于时间迭代算法我们无法充分发挥并行计算的优势。因此在长序列上很明显并行算法在计算步数上远小于迭代算法。在GPU上还可以充分利用tensorCore等用于GEMM的优势时间步迭代则不行。但是并行算法内存占用很高我们可以看到空间复杂度呈平方增长这又失去了线性注意力的优势。为了实现高效的计算分块并行就成为了一个权衡两者的利弊的一个有效方式这样可以充分利用计算资源的同时降低内存占用。FWP分块并行机制首先我们规定对应的符号。代表第 个分块。这个分块通过 的方式堆叠。代表 “第个分块中的第个列向量 ”。堆叠里面有个向量是因为我们还规定 .上面的记号对 都成立。代表第t个分块中的第 个状态矩阵。。这样我们就可以改写我们的迭代步骤成混合形式。块内的某一个元素 则为这样对整一个块我们有如下公式注意查询不能看到未来的key和value这样我们就实现了分块并行的线性注意力策略。对于每一层Transformer block计算步数则为 内存复杂度变为 计算复杂度则为 再次回到了线性注意力的计算复杂度空间复杂度上空间复杂度为 训练情况下则需要保存每一个chunk为也维持了线性的增长DeltaNet 分块并行机制DeltaNet 的更新公式如下其中我们有最直观的方式就是将后面的一部分重新表示成一个新的向量 并且将 吸收进去可以得到如下公式从而得到整体的计算公式。这样通过上一节的堆叠和掩码方式得到并行计算的方式也就是但是这样的表示真的正确吗我们存在如下的问题(1) 需要上一个状态的直接计算导致实际上无法针对 进行并行计算。(2) 计算每个 都需要上一个状态的状态矩阵 导致内存占用从 上升到 。重新定义 降低内存占用但是这同时也带来了新的问题我们的空间复杂度从 上升到了 。回顾我们的导出过程以及上面的公式我们可以得到这就意味着我们在计算矩阵 的时候总是需要保证我们至少存取了上一次的关联矩阵 这样我们的实际的内存复杂度就应该为 。对此我们需要重新定义 不再存储过去的状态矩阵。通过数学归纳法来得到新的 。回顾我们的导出过程我们有假设 . 这样我们有归纳起始条件. 如果有 则有这样 将不再需要读取上一次的状态矩阵 每次计算只需要存取 . 此时的内存复杂度再次回到曾经的 。实际上上面使用的数学归纳法的灵感来源于如下的矩阵计算相关HouseHolder变换和WY表示[12]HouseHolder 变换对于一个非零向量 ,如果一个矩阵 满足则这个矩阵称为 HouseHolder 矩阵。 称为 HouseHolder 向量。对于一个向量 , 称为HouseHolder变换。我们很容易发现 HouseHolder 变换是一个 rank-1 的修正。因为非零向量外积的秩永远为1所有的列都落在 中。WY表示假设 是一个 rank-r 的修正这样我们有 , .证明采用数学归纳法。我们假定 . 因此我们有这样很明显我们的WY表示是成立的。但是如同先前需要实现并行或者分块并行的理由相同——GPU等加速器更适合并行矩阵运算时间步数 的算法不适合在对应硬件上实现。因此我们需要实现DeltaNet的分块并行运算。DeltaNet的分块并行针对我们的DeltaNet的分块并行算法我们需要做一系列比较复杂的变换。我们首先将原本的公式表示成如下所示。我们需要充分利用广义householder变换的性质。因此我们定义如下因此我们重写 。通过循环迭代我们得到接着我们定义分块矩阵相关符号。代表第t个分块中的第 个状态矩阵并且。. 这对 同样适用。这就是我们原文的初始分块模式。但是存储 和 在训练/prefill过程中需要 的存储空间假设 我们可以通过类似上面的WY表示的数学归纳法来降低内存占用到 。我们需要分析一下 和 。 我们将他们展开可以得到