Flash Attention:少搬几趟,比算得更快重要

📅 2026/7/6 3:10:08
Flash Attention:少搬几趟,比算得更快重要
Flash Attention少搬几趟比算得更快重要GPU 的算力每代翻倍但显存带宽十年没怎么涨。所以 GPU 大部分时间不是在算是在等数据搬过来。Flash Attention 做的事情只有一件少搬。标准做法算一步搬一趟Standard vs Flash: 一个 Cycle 级视角不看结构看过程 —— 数据在 HBM 与 SRAM 之间到底怎么跑Standard AttentionCycleALU 状态HBM 读写1等待 Q[i]READ Q[i]2计算 S[i,:]—3等待写回 SWRITE S NxN4等待读回 SREAD S NxN5Softmax(P)—6等待写回 PWRITE P NxN7等待读回 PVREAD PV8O P * VWRITE O[读Q] -- [算S] -- [写S回HBM] -- [读S回来] -- [Softmax] -- [写P回HBM] -- [读PV] -- [算O]8 步中 5 步在等 IOStall 率约 62%5 / 8 cycles 在等搬运中间矩阵 S, P 各写回读回一次 O(N^2) 带宽Flash AttentionCycleSRAM 内操作HBM 读写1加载 TileREAD Tile2S Q * K^T—3Online Softmax—4O P * V—5加载下一 TileREAD Tile6-8SSoftmaxO—9写回 OWRITE O[读Tile] -- [SSoftmaxO 全在SRAM] -- [读下个Tile] -- ... -- [写回O]中间结果不落地仅首尾各一次 HBM 读写仅 2 / 9 cycles 涉及 IOS, P 永远不写回 HBM O(N) 带宽左边的 Standard每算一个中间结果都要写回显存再读回来。8 个 cycle 里 5 个在等。这不是算得慢是搬得慢。Flash 做法搬一次算到底右边。切成 Tile一次搬进 SRAM在 SRAM 里把 S、Softmax、O 全算完再搬下一块。中间结果不落地。搬运量从 O(N²) 降到 O(N)Memory Access: O(N^2) vs O(N)数据实证 —— 随序列长度增长搬运量的天壤之别Memory Access (GB)01020304050Sequence Length N1K2K4K6K8K0.72.811.325.345 GB0.251.02.0 GB22.5xStandard: O(N^2)Flash: O(N)N8K 时45 GB vs 2 GB。22.5 倍仅仅是少搬了几趟。但分块之后Softmax 怎么算传统 Softmax 需要看到整行才能算分母。分块之后只能看到局部。Online Softmax 的解法维护两个局部变量增量更新。row_m到目前为止的最大值row_l到目前为止的累加和每来一个 Tile用旧的 m 和 l 算出新的 m 和 l。不需要全局视野局部状态正确演进就能到达全局正确。Online Softmax 逐步推演m 和 l 的增量进化不看公式看过程 —— row_m 和 row_l 在每个 Tile 到来后到底怎么变StepTile 数据row_m (当前最大值)row_l (累加和)O (累加输出)Init—m -infl 0O [0, 0]Tile 0S [2.0, 1.0]m 2.0max(-inf, 2.0)l 1.370 e^0 e^-1O V0 * attn0m 驱动 lTile 1S [3.0, 0.5]m 3.0max(2.0, 3.0)l 1.881.37*e^-1 e^0 e^-2.5O * exp(-1);O V1 * attn1m 变大 - 旧 O 必须缩放核心洞察当 m 增大时旧的 softmax 分母变小了所以旧的 l 和 O 必须乘以 exp(m_prev - m_new) 来修正这就是为什么 row_m 和 row_l 必须同时维护 —— m 变化时 l 必须联动修正Tile 2S [1.5, 2.5]m 3.0max(3.0, 2.5) 不变l 3.031.88 e^-1.5 e^-0.5O V2 * attn2m 不变无需缩放Final—m 3.0l 3.03O O / l最终除以 l 归一化Online Softmax 伪代码对应 forward_kernelm -inf, l 0, O 0for each Tile j:S_j Q_i * K_j^T // SRAM 内计算m_new max(m, max(S_j)) // 局部最大值更新l_new l * exp(m - m_new) sum(exp(S_j - m_new)) // 修正旧 l 累加新 lO O * exp(m - m_new) exp(S_j - m_new) V_j // 修正旧 O 累加新 Om m_new, l l_new // 提交更新图里 Tile 1 那行是关键m 从 2.0 变成 3.0旧的 l 和 O 必须乘exp(m_prev - m_new)修正。m 变了l 就得跟着变O 也得跟着缩放。这条因果链是 Online Softmax 的全部秘密。对应代码就是row_m_prev → row_m_newrow_l_prev → row_l_new。反向传播重新算比搬回来便宜Flash Attention 反向传播时不存中间矩阵 S而是从 Q、K、V 重新算。看似浪费实则最优A100 的 312 TFLOPS 算力是过剩的2 TB/s 的 HBM 带宽是稀缺的。用富余换稀缺。一句话总结在受限系统里管理移动比管理计算重要。少搬一趟的收益远大于算得快一点。这个规律不只在 GPUCPU 优化 减少 Cache Miss数据库优化 减少磁盘 IO分布式优化 计算向数据靠拢都是同一件事尊重物理约束让数据少跑路。