从 Flash Attention 到 Speculative Decoding:大模型推理加速最全解读

📅 2026/7/2 3:56:37
从 Flash Attention 到 Speculative Decoding:大模型推理加速最全解读
用过 ChatGPT 的人都知道模型生成回答需要等几秒甚至几十秒。这背后的瓶颈在哪有没有办法让大模型说快一点本文从底层原理出发拆解 6 大类加速方案帮你建立起完整的推理加速知识框架。01 一个核心问题大模型生成为什么慢要理解加速方法先得理解大模型到底是怎么生成内容的。大模型生成文本的过程本质上是一个逐字接龙游戏模型先看到你输入的 Prompt这叫Prefill阶段一次处理所有输入 Token然后一个词一个词地往外蹦这叫Decode阶段每次生成一个 Token那每一步具体怎么算的呢这就要说到 Transformer 最核心的模块——Self-Attention。简单来说每个 Token 都会生成三个向量Query查询、Key键、Value值。生成下一个 Token 时它的 Query 会跟前面所有 Token 的 Key 算相似度dot product然后用这个相似度去加权汇总所有的 Value。你看问题就出在这里每生成一个 Token都要跟前面所有的 Token 做一次注意力计算。Token 越多计算量就越大等待时间就越长。那怎么加速呢学术界和工业界想出了一系列方法而且它们可以组合使用。下面我们逐一拆解。02 Flash Attention不改变算法只改变数据搬运方式论文地址https://arxiv.org/abs/2205.14135一句话理解Flash Attention 没有改变 Attention 的计算公式它只是改变了计算的顺序让数据少搬几次家从而大幅提速。先搞懂 GPU 的工作方式要理解 Flash Attention 为什么快得先知道 GPU 是怎么干活的GPU 有很多Execution Unit执行单元它们负责算数每个执行单元有一个很小的工作台——SRAM速度极快但容量极小大量的数据存在仓库里——HBM高带宽内存容量大但读写慢所以每次运算时执行单元得先把数据从 HBM 搬到 SRAM算完再搬回去。Attention 计算的瓶颈不是算得慢是搬数据搬得慢。传统 Attention 计算分好几步每步都要读写 HBM来回搬很多次搬 Q、K → 算 dot product → 搬结果回 HBM找最大值 → 搬回 HBM算分母Softmax 求和→ 搬回 HBM算最终 attention weight → 搬回 HBM搬 V 进来做 weighted sum → 搬结果回 HBM整个过程HBM 被反复读写。Flash Attention 的核心思路Flash Attention 把上面这些步骤合并了——不需要先算出完整的 attention weight再去做 weighted sum。它在 SRAM 上边算边合并一步到位。具体来说它把 K 和 V 切成多个 Chunk每次只处理一个 Chunk关键技巧是处理下一个 Chunk 时如果发现前面 Chunk 用的最大值小了可以用一个修正项调整前面已经算好的结果不需要重算。用公式来表达就是当处理第 2 个 Chunk 时之前算的o1o_1o1​需要调整o2o1s1s2(ed1−d2)∑iN12Neai−d2s2vi o_2 o_1 \frac{s_1}{s_2}(e^{d_1 - d_2}) \sum^{2N}_{iN1}{\frac{e^{a_i - d_2}}{s_2}v_i}o2​o1​s2​s1​​(ed1​−d2​)iN1∑2N​s2​eai​−d2​​vi​这个公式看起来很复杂但核心思想很简单前面算的部分乘以一个修正因子即可不需要重算。效果与局限✅不改变结果和标准 Attention 数学上完全等价✅即插即用可以直接用在任何用了 Attention 的模型上✅显著加速Sequence 越长效果越明显⚠️ 如果 Sequence 太短加速效果不明显前面处理本身时间就短03 KV Cache用存储换速度核心思想在 Decode 阶段每生成一个 Token模型都要计算它跟前面所有 Token 的 Attention。但前面的 Token 不会变所以它们的 K 和 V 也不需要重新算。KV Cache 就是把前面 Token 的 K 和 V 存起来每次只算新 Token 的 QKV避免重复计算。这个思路简单直接而且完全不改变 Attention 的计算逻辑。代价内存爆炸KV Cache 的问题也很明显——它太吃内存了。每次输入一个新 Token就要多存一组 K 和 V。而且这还不是一组——Transformer 有多层 × 多头的 K 和 V。拿 Gemma 2 来算笔账每个Token: 46(层)×32(头)×128(维度)×2(FP16)×2(V, K)753664 字节(约736KB) \text{每个Token: } 46(\text{层}) \times 32 (\text{头}) \times 128 (\text{维度}) \times 2 (\text{FP16}) \times 2(\text{V, K}) 753664 \text{ 字节} (\text{约736KB})每个Token:46(层)×32(头)×128(维度)×2(FP16)×2(V, K)753664字节(约736KB)注意这是一个 Token需要的空间。如果 Sequence 长度是 114k那 A100 的 80GB 显存就刚好被填满。KV Cache 让模型变快但它会吃掉大量显存限制了能处理的上下文长度。这是 KV Cache 加速方案的核心矛盾。04 减少 KV 存储三招让 Cache 更省空间既然 KV Cache 太占地方能不能让 K 和 V 少存一点有三个经典方案方案一Multi-Query AttentionMQA多头注意力MHA里每个头都有自己的 K 和 V。MQA 的想法是多个 Query 头共享一组 K 和 V。好处是 KV Cache 大幅减少但问题也很明显共享一组 K/V 太粗暴了模型表现会下降。方案二Grouped-Query AttentionGQAGQA 是一个折中方案把 Query 头分成几组每组共享一组 K 和 V。它介于 MHA 和 MQA 之间在效率和效果之间取了一个平衡。现在很多新模型如 Llama 2/3都用的 GQA。方案三Multi-head Latent AttentionMLA论文地址https://arxiv.org/abs/2405.04434 DeepSeekMLA 的想法更巧妙不直接存 K 和 V而是先把它们压缩成一个低维向量再存用的时候也不一定要解压缩。这样做有两个关键技巧技巧 1在压缩空间里做 dot product把输入 X 压缩成向量ccc存进仓库的就是这个ccc。需要算 Attention 时aq⋅kqTkqTWkc(WkTq)Tc(WkTq)⋅c a q \cdot k q^T k q^T W_k c (W_k^T q)^T c (W_k^T q) \cdot caq⋅kqTkqTWk​c(WkT​q)Tc(WkT​q)⋅c你看不需要先把ccc解压成kkk直接把qqq转一下就能在压缩空间里做 dot product。技巧 2在压缩空间里做 Weighted Sum算完 attention weight 后要做 weighted sum 得到输出oa^1v1a^2v2a^3v3a^4v4a^1Wvc1a^2Wvc2a^3Wvc3a^4Wvc4Wv(a^1c1a^2c2a^3c3a^4c4) \begin{align*} o \hat{a}_1 v_1 \hat{a}_2 v_2 \hat{a}_3 v_3 \hat{a}_4 v_4 \\ \hat{a}_1 W_v c_1 \hat{a}_2 W_v c_2 \hat{a}_3 W_v c_3 \hat{a}_4 W_v c_4 \\ W_v (\hat{a}_1 c_1 \hat{a}_2 c_2 \hat{a}_3 c_3 \hat{a}_4 c_4) \end{align*}o​a^1​v1​a^2​v2​a^3​v3​a^4​v4​a^1​Wv​c1​a^2​Wv​c2​a^3​Wv​c3​a^4​Wv​c4​Wv​(a^1​c1​a^2​c2​a^3​c3​a^4​c4​)​核心洞察先在压缩的ccc上做 weighted sum最后只解压缩一次。这大大减少了计算量。MLA 是一个需要重新训练模型的方法但它被 DeepSeek 等前沿模型采用证明了这条路是可行的。05 Sliding Window Attention Streaming LLM只看附近的内容Sliding Window Attention核心思路很简单每次做 Attention 时不需要看整个 Sequence只看最近的 N 个 Token。但这样模型不就看不到长距离信息了吗有个巧妙的观察Transformer 层数越深Sliding Window Attention 能看到的范围实际上越大。因为第 1 层只看附近几个 Token第 2 层的输入就已经包含了第 1 层窗口里的信息相当于变相扩大了感受野。网络足够深的话即使每个窗口不大也能覆盖很长的范围。混合策略还有一种方案有些层用 Sliding Window有些层用全局 Attention。这样既能节省 KV Cache又能在关键层保持全局视野。Streaming LLM论文地址https://arxiv.org/abs/2309.17453这里有个有趣的发现只用 Sliding Window 效果会变差但如果保留最开始的几个 Token 就好了。而且这招不需要重新训练模型直接改推理代码就行。实验结果显示Streaming LLM 在长 Sequence 上的表现明显优于纯 Window Attention。06 Pruning KV Cache丢掉没用的 K 和 V更直接的方法来了如果有些 K 和 V 根本用不上直接丢掉不就好了研究发现Attention 其实非常稀疏——大部分 Token 的 attention weight 非常小几乎等于没用上。颜色越深表示 attention weight 越大。可以看到只有很少的 Token 被真正 attention 到了。基于这个观察两篇论文提出了不同的裁剪策略Scissorhandshttps://arxiv.org/abs/2305.17118H2Ohttps://arxiv.org/abs/2306.14048核心思路一致如果一个 K/V 长时间没被 Attention 用到就把它从 Cache 里清除。Scissorhands 的实验显示压缩 5 倍的情况下模型表现跟不压缩基本一样。但 ⚠️ 后续研究也发现如果让模型做非常难的任务随意丢弃 K/V 可能会导致表现大幅下降。这个方法适合常规任务关键场景要谨慎。07 跨对话 CacheAgent 场景的大杀器前面的 KV Cache 都是同一个对话里的优化。但 KV Cache 还有一个更高级的玩法——跨对话共享。不同对话里如果出现相同的文本片段它们的 K 和 V 理论上是可以复用的。什么场景最受益AI Agent 场景是跨对话 Cache 的最佳舞台每个 Agent 调用都带着一串 System Prompt角色设定、工具定义、记忆指令等这些内容不同对话间高度一致。使用技巧要让 Cache Hit 率最大化内容的排列顺序有讲究越稳定不动的内容放越前面越可能变动的内容放越后面。另外同一个 Prompt 用不同写法Cache Hit 率可以差很多换一种写法后Cache Hit 明显提高这意味着直接省钱。有一篇论文专门测了这个效果https://arxiv.org/abs/2601.06007结论是用好的 Prompt 写法结合 Cached InputAgent 的花费可以大幅降低。08 总结一张表看清所有加速方案方法说明改变 Attention需要训练主要代价Flash Attention减少 HBM 读写次数优化计算顺序✗✗一点额外运算KV Cache存储已算好的 K 和 V避免重复计算✗✗占用大量显存Multi-Query Attention多个 Query 头共享一组 K/V✓✓可能明显伤害模型能力Grouped-Query AttentionQuery 分组共享 K/V✓✓效果-效率平衡Multi-head Latent Attention压缩 K/V 后再存储✓✓需要重新训练Sliding Window Attention只 Attention 附近 Token✓?可能丢失长距离信息Streaming LLMSliding Window 保留开头的 Token✗✗—Pruning KV Cache丢弃不常用的 K 和 V✓✗复杂任务可能效果下降Speculative Decoding用小模型预估结果大模型校验✗理论上✗小模型额外算力参考资料Flash Attentionhttps://arxiv.org/abs/2205.14135Multi-head Latent AttentionDeepSeekhttps://arxiv.org/abs/2405.04434Streaming LLMhttps://arxiv.org/abs/2309.17453Scissorhandshttps://arxiv.org/abs/2305.17118H2OHeavy-Hitter Oraclehttps://arxiv.org/abs/2306.14048Cached Input 对 Agent 花费的影响https://arxiv.org/abs/2601.06007