注:此文章内容均节选自充电了么创始人,CEO兼CTO陈敬雷老师的新书《自然语言处理原理与实战》(人工智能科学与技术丛书)【陈敬雷编著】【清华大学出版社】
文章目录
- DeepSeek大模型技术系列七
- DeepSeek大模型技术系列七》DeepSeek 突破!NSA——DeepSeek 原生稀疏注意力开启硬件适配与可训练新时代
- 更多技术内容
- 总结
DeepSeek大模型技术系列七
DeepSeek大模型技术系列七》DeepSeek 突破!NSA——DeepSeek 原生稀疏注意力开启硬件适配与可训练新时代
长上下文建模对于下一代语言模型至关重要,然而标准注意力机制的高计算成本带来了巨大的计算挑战。稀疏注意力为在保持模型能力的同时提高效率提供了一个有前景的方向。我们提出了 NSA(原生可训练稀疏注意力机制),这是一种将算法创新与硬件适配优化相结合的原生可训练稀疏注意力机制,以实现高效的长上下文建模。NSA 采用动态分层稀疏策略,结合粗粒度令牌压缩和细粒度令牌选择,既保留全局上下文感知,又保持局部精度。我们的方法在稀疏注意力设计方面有两项关键创新:(1)通过算术强度平衡的算法设计实现显著加速,并针对现代硬件进行了实现优化。(2)实现端到端训练,在不牺牲模型性能的情况下减少预训练计算量。如图 1 所示,实验表明,用 NSA 预训练的模型在通用基准测试、长上下文任务和基于指令的推理中,性能保持或超过全注意力模型。同时,在处理 64k 长度的序列时,NSA 在解码、前向传播和反向传播方面比全注意力机制有显著的加速,验证了其在模型整个生命周期中的效率。
1. 引言
受从深度推理(DeepSeek-AI, 2025; Zelikman 等人,2022)、代码库级代码生成(Zhang 等人,2023a; Zhang 等人)到多轮自主智能体系统(Park 等人,2023)等各种实际应用的推动,研究界越来越认识到长上下文建模是下一代大语言模型的关键能力。最近的突破,包括 OpenAI 的 o 系列模型、DeepSeek-R1(DeepSeek-AI, 2025)和 Gemini 1.5 Pro(Google 等人,2024),使模型能够处理整个代码库、长篇文档,在数千个令牌上保持连贯的多轮对话,并跨越长距离依赖进行复杂推理。然而,随着序列长度的增加,普通注意力(Vaswani 等人,2017)机制的高复杂度(Zaheer 等人,2020)成为关键的延迟瓶颈。理论估计表明,在解码 64k 长度的上下文时,具有 softmax 架构的注意力计算占总延迟的 70 - 80%,这凸显了对更高效注意力机制的迫切需求。
图 1|全注意力模型和我们的国家安全局之间的性能和效率比较。左:尽管稀疏国家安全局在一般基准测试、长上下文任务和推理评估中平均超过全注意力基线右:对于 64k 长度的序列处理,国家安全局在所有阶段(解码、前向传播和反向传播)都比全注意力实现了大量的计算加速。
实现高效长上下文建模的一种自然方法是利用 softmax 注意力的固有稀疏性(Ge 等人,2023; Jiang 等人,2023),即选择性地计算关键的查询 - 键对可以在保持性能的同时显著降低计算开销。最近的进展通过多种策略展示了这种潜力:KV 缓存逐出方法(Li 等人,2024; Zhang 等人,2023b; Zhou 等人,2024)、分块 KV 缓存选择方法(Tang 等人,2024; Xiao 等人,2024)以及基于采样、聚类或哈希的选择方法(Chen 等人,2024; Desai 等人,2024; Liu 等人,2024)。尽管有这些有前景的策略,但现有的稀疏注意力方法在实际部署中往往存在不足。许多方法未能实现与理论增益相匹配的加速;此外,大多数方法主要关注推理阶段,缺乏有效的训练阶段支持,无法充分利用注意力的稀疏模式。
为了解决这些限制,有效的稀疏注意力部署必须应对两个关键挑战:(1)硬件适配的推理加速:将理论上的计算减少转化为实际的速度提升,需要在预填充和解码阶段设计对硬件友好的算法,以减轻内存访问和硬件调度瓶颈;(2)考虑训练的算法设计:使用可训练的操作符实现端到端计算,在保持模型性能的同时降低训练成本。这些要求对于实际应用中实现快速长上下文推理或训练至关重要。综合考虑这两个方面,现有方法仍存在明显差距。
图 2|美国国家安全局架构概述。左:框架通过三个并行注意力分支处理输入序列:对于给定的查询,前面的键和值被处理成粗粒度模式的压缩注意力、重要标记块的选定注意力和本地上下文的滑动注意力。右:每个分支产生的不同注意力式的可视化。绿色区域表示需要计算注意力分数的区域,而白色区域表示可以跳过的区域。
为了实现更有效和高效的稀疏注意力,我们提出了 NSA,这是一种集成了分层令牌建模的原生可训练稀疏注意力架构。如图 2 所示,NSA 通过将键和值组织成时间块,并通过三个注意力路径处理它们来减少每个查询的计算量:压缩的粗粒度令牌、选择性保留的细粒度令牌以及用于局部上下文信息的滑动窗口。然后,我们实现了专门的内核以最大化其实际效率。NSA 针对上述关键要求引入了两项核心创新:(1)硬件适配系统:针对张量核心利用率和内存访问优化分块稀疏注意力,确保算术强度平衡。(2)考虑训练的设计:通过高效算法和反向操作符实现稳定的端到端训练。这种优化使 NSA 既支持高效部署,又支持端到端训练。
我们在真实世界的语言语料库上进行了全面的实验来评估 NSA。在具有 270 亿参数的 Transformer 骨干网络上使用 2600 亿个令牌进行预训练,我们评估了 NSA 在通用语言评估、长上下文评估和思维链推理评估中的性能。我们进一步在 A100 GPU 上与经过优化的 Triton(Tillet 等人,2019)实现进行内核速度比较。实验结果表明,NSA 的性能与全注意力基线相当或更优,同时优于现有的稀疏注意力方法。此外,与全注意力相比,NSA 在解码、前向和反向阶段都实现了显著加速,并且序列越长加速比越高。这些结果验证了我们的分层稀疏注意力设计有效地平衡了模型能力和计算效率。
2. 重新思考稀疏注意力方法
现代稀疏注意力方法在降低 Transformer 模型的理论计算复杂度方面取得了显著进展。然而,大多数方法主要在推理阶段应用稀疏性,同时保留预训练的全注意力骨干网络,这可能会引入架构偏差,限制它们充分利用稀疏注意力优势的能力。在介绍我们的原生稀疏架构之前,我们从两个关键方面系统地分析了这些限制。
2.1 高效推理的错觉
尽管在注意力计算中实现了稀疏性,但许多方法未能相应地减少推理延迟,主要原因有两个:
阶段受限的稀疏性:像 H2O(Zhang 等人,2023b)这样的方法在自回归解码时应用稀疏性,而在预填充阶段需要进行计算密集型的预处理(例如注意力图计算、索引构建)。相比之下,MInference(Jiang 等人,2024)等方法只关注预填充阶段的稀疏性。这些方法无法在所有推理阶段实现加速,因为至少有一个阶段的计算成本与全注意力相当。这种阶段特定性降低了这些方法在以预填充为主的工作负载(如书籍摘要和代码补全)或以解码为主的工作负载(如长思维链(Wei 等人,2022)推理)中的加速能力。
与先进注意力架构不兼容:一些稀疏注意力方法无法适应现代高效解码架构,如多查询注意力(MQA)(Shazeer, 2019)和分组查询注意力(GQA)(Ainslie 等人,2023),这些架构通过在多个查询头之间共享 KV 显著减少了解码期间的内存访问瓶颈。例如,在 Quest(Tang 等人,2024)等方法中,每个注意力头独立选择其 KV 缓存子集。虽然在多头注意力(MHA)模型中它表现出一致的计算稀疏性和内存访问稀疏性,但在基于 GQA 等架构的模型中情况不同,在 GQA 中,KV 缓存的内存访问量对应于同一 GQA 组内所有查询头选择的并集。这种架构特征意味着,虽然这些方法可以减少计算操作,但所需的 KV 缓存内存访问仍然相对较高。这一限制带来了一个关键选择:一些稀疏注意力方法在减少计算的同时,其分散的内存访问模式与先进架构的高效内存访问设计相冲突。
这些限制的出现是因为许多现有的稀疏注意力方法专注于减少 KV 缓存或理论计算量,但在先进框架或后端中难以显著降低延迟。这促使我们开发结合先进架构和硬件高效实现的算法,以充分利用稀疏性来提高模型效率。
2.2 可训练稀疏性的误区
通过分析仅用于推理的方法,我们追求原生可训练稀疏注意力的动机源于两个关键见解:(1)性能下降:事后应用稀疏性会迫使模型偏离其预训练的优化轨迹。正如 Chen 等人(2024)所示,前 20% 的注意力只能覆盖总注意力分数的 70%,这使得预训练模型中的检索头等结构在推理过程中容易被剪枝。(2)训练效率需求:高效处理长序列训练对于现代大语言模型的发展至关重要。这包括在更长的文档上进行预训练以增强模型能力,以及后续的适应阶段,如长上下文微调强化学习。然而,现有的稀疏注意力方法主要针对推理,在很大程度上未解决训练中的计算挑战。这一限制阻碍了通过高效训练开发更强大的长上下文模型。此外,尝试将现有的稀疏注意力方法应用于训练也暴露出一些挑战:
不可训练的组件:像 ClusterKV(Liu 等人,2024)(包括 k 均值聚类)和 MagicPIG(Chen 等人,2024)(包括基于 SimHash 的选择)等方法中的离散操作会在计算图中产生不连续性。这些不可训练的组件阻止了梯度在令牌选择过程中流动,限制了模型学习最优稀疏模式的能力。
低效的反向传播:一些理论上可训练的稀疏注意力方法在实际训练中效率低下。例如 HashAttention(Desai 等人,2024)中使用的令牌粒度选择策略导致在注意力计算期间需要从 KV 缓存中加载大量单个令牌。这种不连续的内存访问使得像 FlashAttention 这样依赖连续内存访问和分块计算来实现高吞吐量的快速注意力技术无法有效应用。结果,实现被迫退回到低硬件利用率状态,显著降低了训练效率。
2.3 原生稀疏性的必要性
推理效率和训练可行性方面的这些限制促使我们对稀疏注意力机制进行根本性的重新设计。我们提出了 NSA,这是一种原生稀疏注意力框架,解决了计算效率和训练要求这两个问题。在以下部分,我们将详细介绍 NSA 的算法设计和操作符实现。
3. 方法
我们的技术方法涵盖算法设计和内核优化。在以下小节中,我们首先介绍我们方法的背景。然后介绍 NSA 的整体框架,接着介绍其关键算法组件。最后,我们详细介绍针对硬件优化的内核设计,以最大化实际效率。
3.1 背景
注意力机制广泛应用于语言建模中,其中每个查询令牌 q 会针对所有前面的键 计算相关性分数,以生成值 的加权和。形式上,对于长度为 t 的输入序列,注意力操作定义为:
其中 Attn 表示注意力函数:
随着序列长度的增加,注意力计算在整体计算成本中所占的比重越来越大,这给长上下文处理带来了重大挑战。
算术强度是计算操作数与内存访问数的比率。它本质上影响着硬件上的算法优化。每个 GPU 都有一个由其峰值计算能力和内存带宽决定的关键算术强度,通过这两个硬件限制的比率来计算。对于计算任务,算术强度高于这个关键阈值时,计算受限于 GPU 的浮点运算次数(FLOPS);低于这个阈值时,计算受限于内存带宽。
具体来说,对于因果自注意力机制,在训练和预填充阶段,批量矩阵乘法和注意力计算具有较高的算术强度,这使得这些阶段在现代加速器上受限于计算能力。相比之下,自回归解码受限于内存带宽,因为它在每次前向传递时只生成一个令牌,但需要加载整个键值缓存,导致算术强度较低。这就导致了不同的优化目标:在训练和预填充阶段降低计算成本,在解码阶段减少内存访问。
3.2 整体框架
为了利用注意力的自然稀疏模式的潜力,我们建议用一组更紧凑、信息更密集的表示键值对 、 来替换公式(1)中的原始键值对 、 。具体来说,我们正式定义优化后的注意力输出如下:
其中 、 是基于当前查询 和上下文记忆 、 动态构建的。我们可以设计各种映射策略来获得不同类别的 、 ,并将它们组合如下:
如图 2 所示,NSA 有三种映射策略 ,分别代表键和值的压缩、选择和滑动窗口。 是相应策略 c 的门控分数,通过一个 MLP 和 sigmoid 激活函数从输入特征中得出。令 表示重新映射的键 / 值的总数:
我们通过确保 来保持较高的稀疏率。
3.3 算法设计
在本小节中,我们将介绍重新映射策略 和 的设计:令牌压缩、令牌选择和滑动窗口。
3.3.1 令牌压缩
通过将连续的键或值块聚合为块级表示,我们得到了能够捕获整个块信息的压缩键和值。形式上,压缩键表示定义为:
其中 l 是块长度,d 是相邻块之间的滑动步长, 是一个带有块内位置编码的可学习 MLP,用于将块中的键映射到单个压缩键。 是由压缩键组成的张量。通常,我们采用 来减少信息碎片化。类似的公式适用于压缩值表示 。压缩表示捕获更粗粒度的高级语义信息,并减少注意力的计算负担。
3.3.2 令牌选择
仅使用压缩键和值可能会丢失重要的细粒度信息,因此我们需要选择性地保留单个键和值。下面我们将介绍一种高效的令牌选择机制,它能够以较低的计算开销识别并保留最相关的令牌。
分块选择:我们的选择策略以空间连续的块为单位处理键和值序列,这是出于两个关键因素的考虑:硬件效率和注意力分数的固有分布模式。分块选择对于在现代 GPU 上实现高效计算至关重要。这是因为现代 GPU 架构在连续块访问时的吞吐量明显高于随机索引读取。此外,分块计算能够优化张量核心的利用率。这种架构特性使得分块内存访问和计算成为高性能注意力实现的基本原则,FlashAttention 的分块设计就是一个例子。分块选择遵循注意力分数的固有分布模式。先前的研究(Jiang 等人,2024)表明,注意力分数通常表现出空间连续性,这意味着相邻的键往往具有相似的重要性水平。我们在 6.2 节中的可视化也展示了这种空间连续模式。为了实现分块选择,我们首先将键和值序列划分为选择块。
重要性分数计算:计算块重要性分数可能会带来显著的开销。幸运的是,压缩令牌的注意力计算会产生中间注意力分数,我们可以利用这些分数来推导选择块的重要性分数,公式为:
其中是与压缩键之间的注意力分数。设表示选择块的大小。当压缩块和选择块的划分方案相同时,即,我们可以直接得到选择块的重要性分数,即 。对于划分方案不同的情况,我们根据它们的空间关系推导选择块的重要性分数。假设且,我们有:
其中表示访问向量元素的索引操作符。对于采用 GQA 或 MQA 的模型,其中键值缓存是跨查询头共享的,必须确保这些头之间的块选择一致,以最小化解码期间的 KV 缓存加载。同一组内跨头的共享重要性分数正式定义为:
其中上标表示头索引,是每组中的查询头数量。这种聚合确保了同一组内各头之间块选择的一致性。
Top - n 块选择:在获得选择块的重要性分数后,我们保留按块重要性分数排名前的稀疏块内的令牌,公式为:
其中表示降序排列的排名位置,对应最高分,是选定块的索引集合,表示拼接操作。是由压缩键组成的张量。类似的公式适用于细粒度值 。然后,选定的键和值按照公式(5)与参与注意力计算。
3.3.3 滑动窗口
在注意力机制中,局部模式通常适应得更快,并且可能主导学习过程,这可能会阻止模型有效地从压缩和选择的令牌中学习。为了解决这个问题,我们引入了一个专门的滑动窗口分支,明确处理局部上下文,使其他分支(压缩和选择)能够专注于学习各自的特征,而不会被局部模式干扰。具体来说,我们在窗口中维护最近的令牌 、,并将不同信息源(压缩令牌、选定令牌、滑动窗口)的注意力计算隔离到单独的分支中。这些分支的输出通过一个学习到的门控机制进行聚合。为了在几乎不增加计算开销的情况下进一步防止注意力分支之间的捷径学习,我们为三个分支提供独立的键和值。这种架构设计通过防止局部和长距离模式识别之间的梯度干扰,实现了稳定的学习,同时引入的开销最小。
在获得所有三类键和值(,;,;以及,)后,我们按照公式(5)计算最终的注意力输出。结合上述的压缩、选择和滑动窗口机制,这就构成了 NSA 完整的算法框架。
3.4 内核设计
为了在训练和预填充阶段实现与 FlashAttention 相当的加速效果,我们基于 Triton 实现了硬件适配的稀疏注意力内核。鉴于多头注意力(MHA)在解码时内存需求大且效率低下,我们遵循当前最先进的大语言模型,重点关注具有共享 KV 缓存的架构,如 GQA 和 MQA。虽然压缩和滑动窗口注意力计算很容易与现有的 FlashAttention - 2 内核兼容,但我们针对稀疏选择注意力引入了专门的内核设计。如果我们遵循 FlashAttention 将时间上连续的查询块加载到 SRAM 的策略,由于块内的查询可能需要不连续的 KV 块,这将导致内存访问效率低下。为了解决这个问题,我们的关键优化在于采用不同的查询分组策略:对于查询序列上的每个位置,我们将 GQA 组内的所有查询头(它们共享相同的稀疏 KV 块)加载到 SRAM 中。
图 3 展示了我们的前向传递实现。所提出的内核架构具有以下关键特征:
1.以组为中心的数据加载:对于每个内循环,在位置加载组内所有头的查询以及它们共享的稀疏键 / 值块索引。
2.共享 KV 获取:在内循环中,按顺序将由索引的连续键 / 值块作为、加载到 SRAM 中,以最小化内存加载,其中是满足的内核块大小。
3.网格上的外循环:由于不同查询块的内循环长度(与选定块的数量成正比)几乎相同,我们将查询 / 输出循环放入 Triton 的网格调度器中,以简化和优化内核。
图 3INSA 的内核设计。内核按 GQA 组(网格循环)加载查询,获取相应的稀疏 KV 块(内部循环),并在 SRAM 上执行注意力计算。绿色块表示 SRAM 上的数据,而蓝色块表示 HBM 上的数据。
这种设计通过(1)通过组内共享消除冗余的 KV 传输,以及(2)平衡 GPU 流式多处理器之间的计算负载,实现了接近最优的算术强度。
4. 实验
我们从三个方面评估 NSA:(1)通用基准测试性能;(2)长上下文基准测试性能;(3)思维链推理性能,并与全注意力基线和最先进的稀疏注意力方法进行比较。我们将稀疏计算范式的效率分析推迟到第 5 节,在那里我们将详细讨论训练和推理速度。
4.1 预训练设置
遵循当前最先进的大语言模型的常见做法,我们的实验采用了结合分组查询注意力(GQA)和专家混合(MoE)的骨干网络,总参数为 270 亿,其中活跃参数为 30 亿。该模型由 30 层组成,隐藏层维度为 2560。对于 GQA,我们将组数设置为 4,总共有 64 个注意力头。对于每个头,查询、键和值的隐藏维度分别配置为和。对于 MoE,我们采用 DeepSeekMoE(Dai 等人,2024; DeepSeek - AI, 2024)结构,有 72 个路由专家和 2 个共享专家,并将 top - k 专家设置为 6。为了确保训练的稳定性,第一层的 MoE 被替换为 SwiGLU 形式的 MLP。
图 4|全注意力和我们的 NSA 在 278 参数模型上的预训练损失比较。两种模型都表现出稳定的收敛性,NSA 实现了较低的损失值。
我们为 NSA 设置压缩块大小、滑动步长、选定块大小、选定块数量(包括固定激活的 1 个初始块和 2 个局部块)以及滑动窗口大小。全注意力模型和稀疏注意力模型都在 8k 长度文本的 2700 亿个令牌上进行预训练,然后使用 YaRN(Peng 等人,2024)在 32k 长度的文本上继续训练和监督微调,以实现长上下文适应。两个模型都训练至完全收敛,以确保公平比较。如图 4 所示,我们的 NSA 和全注意力基线的预训练损失曲线均稳定下降,且 NSA 始终优于全注意力模型。
4.2 基线方法
除了与全注意力进行比较外,我们还评估了几种最先进的推理阶段稀疏注意力方法:H2O(Zhang 等人,2023b)、infLLM(Xiao 等人,2024)、Quest(Tang 等人,2024)和 Exact - Top(Exact - Top 首先计算全注意力分数,为每个查询选择前个相应的键,然后在这些位置上计算注意力)。这些方法涵盖了不同的稀疏注意力范式,包括 KV 缓存逐出、查询感知选择和精确的 top - n 稀疏选择。
对于通用评估,由于大多数样本的长度在稀疏注意力基线的局部上下文窗口内,这些方法实际上与全注意力相当。因此,在这种情况下,我们仅展示 NSA 与全注意力基线的比较结果。在长上下文评估中,我们对所有基线方法进行比较,并将所有稀疏注意力方法的稀疏度设置为相同,以确保公平比较。对于需要长文本监督微调的思维链推理评估,我们仅与全注意力进行比较,因为稀疏注意力基线不支持训练。
4.3 性能比较
通用评估:我们在一系列涵盖知识、推理和编码能力的基准测试中评估了预训练的 NSA 和全注意力基线,包括 MMLU(Hendrycks 等人,2020)、MMLU - PRO(Wang 等人,2024)、CMMLU(Li 等人,2023)、BBH(Suzgun 等人,2022)、GSM8K(Cobbe 等人,2021)、MATH(Hendrycks 等人,2020)、DROP(Dua 等人,2019)、MBPP(Austin 等人,2021)和 HumanEval(Chen 等人,2021)。结果如表 1 所示。尽管 NSA 具有稀疏性,但它在整体性能上更优,在 9 个指标中有 7 个超过了包括全注意力在内的所有基线。这表明,虽然 NSA 在较短序列上可能无法充分发挥其效率优势,但它仍表现出强大的性能。值得注意的是,NSA 在与推理相关的基准测试中表现出显著的提升(DROP:提高 0.042,GSM8K:提高 0.034),这表明我们的预训练有助于模型开发专门的注意力机制。这种稀疏注意力预训练机制迫使模型专注于最重要的信息,可能通过过滤掉无关注意力路径中的噪声来提高性能。在各种评估中的一致表现也验证了 NSA 作为通用架构的稳健性。
长上下文评估:如图 5 所示,NSA 在 64k 上下文的 “大海捞针”(Kamradt, 2023)测试中,所有位置的检索准确率均达到完美。这种性能源于我们的分层稀疏注意力设计,它结合了用于高效全局上下文扫描的压缩令牌和用于精确局部信息检索的选择令牌。粗粒度压缩以较低的计算成本识别相关的上下文块,而对选定令牌的令牌级注意力确保了关键细粒度信息的保留。这种设计使 NSA 能够同时保持全局感知和局部精度。
我们还在 LongBench(Bai 等人,2023)上对 NSA 与最先进的稀疏注意力方法和全注意力基线进行了评估。为了确保稀疏度一致,我们将所有稀疏注意力基线中每个查询激活的令牌数设置为 2560,这与 NSA 处理 32k 序列长度时激活的平均令牌数相对应。遵循 StreamLLM(Xiao 等人,2023)的设置,这个令牌预算包括前 128 个令牌和 512 个局部令牌。由于某些子集在所有模型上的得分都较低,可能无法提供有意义的比较,我们将其从 LongBench 中排除。如表 2 所示,NSA 获得了最高的平均得分 0.469,超过了所有基线(比全注意力高 0.032,比 Exact - Top 高 0.046)。这一改进源于两项关键创新:(1)我们的原生稀疏注意力设计,它允许在预训练期间对稀疏模式进行端到端优化,促进了稀疏注意力模块与模型其他组件之间的同步适应;(2)分层稀疏注意力机制在局部和全局信息处理之间实现了平衡。
值得注意的是,NSA 在需要长上下文复杂推理的任务上表现出色,在多跳问答任务(HPQ 和 2Wiki)上比全注意力提高了 0.087 和 0.051,在代码理解(LCC:提高 0.069)方面超过了基线,在段落检索(PassR - en:提高 0.075)方面也优于其他方法。这些结果验证了 NSA 处理各种长上下文挑战的能力,其原生预训练的稀疏注意力在学习任务最优模式方面提供了额外的优势。
思维链推理评估:为了评估 NSA 与先进的下游训练范式的兼容性,我们研究了它通过训练后学习获得思维链数学推理能力的能力。鉴于强化学习在较小规模模型上的效果有限,我们采用了来自 DeepSeek - R1 的知识蒸馏,使用 100 亿个 32k 长度的数学推理轨迹进行监督微调(SFT)。这产生了两个可比较的模型:全注意力 - R(全注意力基线)和 NSA - R(我们的稀疏变体)。我们在具有挑战性的美国数学邀请赛(AIME 24)基准上评估这两个模型。我们使用 0.7 的采样温度和 0.95 的 top - p 值为每个问题生成 16 个答案,并获得平均得分。为了验证推理深度的影响,我们进行了两个生成上下文限制的实验:8k 和 16k 令牌,以测量扩展推理链是否能提高准确率。模型预测的示例比较见附录 A。
如表 3 所示,在 8k 上下文设置下,NSA - R 的准确率显著高于全注意力 - R(提高 0.075),在 16k 上下文下这一优势仍然存在(提高 0.054)。这些结果验证了原生稀疏注意力的两个关键优势:(1)预训练的稀疏注意力模式能够有效捕获复杂数学推导中至关重要的长距离逻辑依赖;(2)我们架构的硬件适配设计保持了足够的上下文密度,以支持不断增加的推理深度,而不会出现灾难性遗忘。在不同上下文长度下的一致优势证实了,当原生稀疏注意力集成到训练流程中时,它在先进推理任务中的可行性。
5. 效率分析
我们在一个 8 - GPU 的 A100 系统上评估 NSA 与全注意力的计算效率。在效率分析中,我们还将模型配置为 GQA 组,每组头数,查询 / 键维度,值维度。遵循第 4 节中的相同设置,我们将 NSA 的压缩块大小设置为,滑动步长,选定块大小,选定块数量,滑动窗口大小。
5.1 训练速度
我们将基于 Triton 实现的 NSA 注意力和全注意力与基于 Triton 的 FlashAttention - 2 进行比较,以确保在相同后端上进行公平的速度比较。如图 6 所示,随着上下文长度的增加,NSA 的加速比逐渐增大,在 64k 上下文长度下,前向加速比可达 9.0 倍,反向加速比可达 6.0 倍。值得注意的是,序列越长,速度优势越明显。这种加速源于我们的硬件适配算法设计,以最大化稀疏注意力架构的效率:(1)分块内存访问模式通过合并加载最大化了张量核心的利用率;(2)内核中精细的循环调度消除了冗余的 KV 传输。
5.2 解码速度
注意力的解码速度主要由内存访问瓶颈决定,这与 KV 缓存加载量密切相关。在每个解码步骤中,我们的 NSA 最多只需加载个压缩令牌、个选定令牌和个相邻令牌,其中是缓存的序列长度。如表 4 所示,随着解码长度的增加,我们的方法延迟显著降低,在 64k 上下文长度下加速比可达 11.6 倍。这种内存访问效率的优势也随着序列长度的增加而放大。
6. 讨论
在本节中,我们回顾 NSA 的开发过程,并讨论从探索不同稀疏注意力策略中获得的关键见解。虽然我们的方法取得了有前景的结果,但了解替代策略所面临的挑战以及分析注意力模式,为未来的研究方向提供了有价值的背景。我们首先探讨促使我们进行设计选择的替代令牌选择策略所面临的挑战,接着通过可视化来深入了解注意力分布模式。
6.1 替代令牌选择策略面临的挑战
在设计 NSA 之前,我们尝试将现有的稀疏注意力方法应用于训练阶段,但这些尝试遇到了各种挑战,促使我们设计一种不同的稀疏注意力架构:
基于关键聚类的策略:我们研究了像 ClusterKV(Liu 等人,2024)这样基于聚类的策略。这些方法将同一聚类中的键和值存储在连续的内存区域中。虽然在理论上适用于训练和推理,但它们面临三个重大挑战:(1)动态聚类机制会带来不可忽视的计算开销;(2)聚类间的不平衡加剧了算子优化的难度,尤其是在专家混合(MoE)系统中,不均衡的专家并行(EP)组执行时间会导致持续的负载不平衡;(3)由于需要定期进行强制重新聚类和分块顺序训练协议,存在实现上的限制。这些因素综合起来形成了巨大的瓶颈,显著限制了它们在实际部署中的有效性。
其他分块选择策略:我们还考虑了与 NSA 不同的分块键值选择策略,如 Quest(Tang 等人,2024)和 InfLLM(Xiao 等人,2024)。这些方法依赖于为每个块计算一个重要性分数,并根据与的相似度选择前个块 。然而,现有方法面临两个关键问题:(1)由于选择操作不可微,基于神经网络的重要性分数计算依赖于辅助损失,这增加了算子开销,并且常常会降低模型性能;(2)启发式的无参数重要性分数计算策略召回率较低,导致性能次优。我们在一个具有相似架构的 30 亿参数模型上评估了这两种方法,并将它们的损失曲线与 NSA 和全注意力进行比较。对于基于辅助损失的选择方法,我们为每个块引入额外的查询和代表性键来估计块重要性分数,这些分数由每个块内原始查询和键之间的平均注意力分数进行监督。对于启发式无参数选择方法,我们按照 Quest 的策略,使用查询与键块的坐标最小 - 最大值的乘积进行直接选择,不引入额外参数。我们还探索了一种冷启动训练方法,即在最初的 1000 步使用全注意力,然后切换到启发式分块选择。如图 7 所示,这两种方法的损失都较差。
6.2 可视化
为了探索 Transformer 注意力分布中的潜在模式,并为我们的设计寻求灵感,我们在图 8 中可视化了预训练的 270 亿参数全注意力模型的注意力图。可视化结果揭示了有趣的模式,注意力分数倾向于呈现分块聚类的特征,相邻的键通常具有相似的注意力分数。这一观察结果启发了 NSA 的设计,表明基于空间连续性选择关键块可能是一种有前途的方法。分块聚类现象表明,序列中相邻的令牌可能与查询令牌共享某些语义关系,尽管这些关系的具体性质需要进一步研究。这一观察促使我们探索一种基于连续令牌块而非单个令牌的稀疏注意力机制,旨在提高计算效率并保留高注意力模式。
7. 相关工作
我们回顾了现有的通过稀疏注意力提高注意力计算效率的方法。这些方法根据其核心策略大致可分为三类:(1)固定稀疏模式;(2)动态令牌剪枝;(3)查询感知选择。我们从每类中介绍几个代表性的工作。
7.1 固定稀疏模式
滑动窗口(SlidingWindow)是一种常用的方法,它允许查询仅在固定窗口内计算注意力。StreamingLLM(Xiao 等人,2023)通过维护上下文的两个关键部分:注意力汇聚(早期令牌)和局部上下文窗口,来解决处理长文本流的挑战。虽然这些方法有效地降低了内存和计算成本,但它们忽略上下文的固定模式限制了其在需要全面理解上下文的任务上的性能。
7.2 动态令牌剪枝
H2O(Zhang 等人,2023b)实现了一种自适应方法,在解码期间减少 KV 缓存的内存使用。该方法根据注意力分数,基于令牌的近期效用动态地逐出对未来预测不太重要的令牌。SnapKV(Li 等人,2024)也引入了一种令牌剪枝策略,通过选择性地保留最关键的特征来减少 KV 缓存,实现高效的内存使用。SnapKV 在预填充期间通过注意力权重分析和投票识别重要特征,然后通过将选定的压缩特征与最近的上下文相结合来更新 KV 缓存,以保持提示的一致性。
7.3 查询感知选择
Quest(Tang 等人,2024)采用分块选择策略,通过查询与键块的坐标最小 - 最大值的乘积来估计每个块的重要性。结果分数有助于选择前个重要的键值块进行注意力计算。InfLLM(Xiao 等人,2024)通过维护注意力汇聚、局部上下文和可检索块,将固定模式与检索相结合。该方法从每个块中选择代表性键来估计块的重要性。HashAttention(Desai 等人,2024)通过使用学习函数将查询和键映射到汉明空间,将关键令牌识别问题转化为推荐问题。ClusterKV(Liu 等人,2024)通过首先对键进行聚类,然后根据查询 - 聚类相似度选择最相关的聚类进行注意力计算来实现稀疏性。
8. 结论
我们提出了 NSA,这是一种用于高效长上下文建模的硬件适配稀疏注意力架构。通过在可训练架构中集成分层令牌压缩和分块令牌选择,我们的架构在保持全注意力性能的同时,实现了训练和推理的加速。NSA 在通用基准测试中与全注意力基线性能相当,在长上下文评估中超越了现有模型能力,增强了推理能力,同时显著降低了计算延迟并实现了大幅加速,推动了该领域的发展。
更多技术内容
更多技术内容可参见
《自然语言处理原理与实战》(人工智能科学与技术丛书)【陈敬雷编著】【清华大学出版社】书籍。
更多的技术交流和探讨也欢迎加我个人微信chenjinglei66。
总结
此文章有对应的配套新书教材和视频:
【配套新书教材】
《自然语言处理原理与实战》(人工智能科学与技术丛书)【陈敬雷编著】【清华大学出版社】
新书特色:本书从自然语言处理基础开始,逐步深入各种NLP热点前沿技术,使用了Java和Python两门语言精心编排了大量代码实例,契合公司实际工作场景技能,侧重实战。
全书共分为19章,详细讲解中文分词、词性标注、命名实体识别、依存句法分析、语义角色标注、文本相似度算法、语义相似度计算、词频-逆文档频率(TF-IDF)、条件随机场、新词发现与短语提取、搜索引擎Solr Cloud和Elasticsearch、Word2vec词向量模型、文本分类、文本聚类、关键词提取和文本摘要、自然语言模型(Language Model)、分布式深度学习实战等内容,同时配套完整实战项目,例如对话机器人实战、搜索引擎项目实战、推荐算法系统实战。
本书理论联系实践,深入浅出,知识点全面,通过阅读本书,读者不仅可以理解自然语言处理的知识,还能通过实战项目案例更好地将理论融入实际工作中。
《分布式机器学习实战》(人工智能科学与技术丛书)【陈敬雷编著】【清华大学出版社】
新书特色:深入浅出,逐步讲解分布式机器学习的框架及应用配套个性化推荐算法系统、人脸识别、对话机器人等实战项目。
【配套视频】
推荐系统/智能问答/人脸识别实战 视频教程【陈敬雷】
视频特色:把目前互联网热门、前沿的项目实战汇聚一堂,通过真实的项目实战课程,让你快速成为算法总监、架构师、技术负责人!包含了推荐系统、智能问答、人脸识别等前沿的精品课程,下面分别介绍各个实战项目:
1、推荐算法系统实战
听完此课,可以实现一个完整的推荐系统!下面我们就从推荐系统的整体架构以及各个子系统的实现给大家深度解密来自一线大型互联网公司重量级的实战产品项目!
2、智能问答/对话机器人实战
由浅入深的给大家详细讲解对话机器人项目的原理以及代码实现、并在公司服务器上演示如何实际操作和部署的全过程!
3、人脸识别实战
从人脸识别原理、人脸识别应用场景、人脸检测与对齐、人脸识别比对、人脸年龄识别、人脸性别识别几个方向,从理论到源码实战、再到服务器操作给大家深度讲解!
自然语言处理NLP原理与实战 视频教程【陈敬雷】
视频特色:《自然语言处理NLP原理与实战》包含了互联网公司前沿的热门算法的核心原理,以及源码级别的应用操作实战,直接讲解自然语言处理的核心精髓部分,自然语言处理从业者或者转行自然语言处理者必听视频!
人工智能《分布式机器学习实战》 视频教程【陈敬雷】
视频特色:视频核心内容有互联网公司大数据和人工智能、大数据算法系统架构、大数据基础、Python编程、Java编程、Scala编程、Docker容器、Mahout分布式机器学习平台、Spark分布式机器学习平台、分布式深度学习框架和神经网络算法、自然语言处理算法、工业级完整系统实战(推荐算法系统实战、人脸识别实战、对话机器人实战)。
上一篇:DeepSeek大模型技术系列三》DeepSeek-R1:通过强化学习激发大语言模型的推理能力
下一篇:DeepSeek大模型技术系列五》DeepSeek大模型基础设施全解析:支撑万亿参数模型的幕后英雄