03 RoPE

📅 2026/6/20 12:45:40
03 RoPE
03 RoPE1. 为什么需要 RoPE1.1 位置编码要解决什么问题Transformer 的 Attention 机制本身是置换不变的——把输入序列打乱Attention 的输出在没有位置信息的情况下也会被打乱。模型不知道第一个 token和第五个 token有什么区别。所以需要在输入中注入位置信息。1.2 绝对位置编码的痛点早期方案如原始 Transformer 的正弦波位置编码、GPT-2 的可学习位置编码都是绝对位置编码inputtoken_embeddingposition_embedding\text{input} \text{token\_embedding} \text{position\_embedding}inputtoken_embeddingposition_embedding致命问题模型在训练时只见过位置 0~40954K 上下文推理时给一个位置 5000 的 token位置编码是没见过的——模型直接懵了。这就是上下文长度外推Context Extension问题的根源。1.3 RoPE 的核心洞察不直接告诉模型这是第几个 token而是让 Attention 计算自然包含 token 之间的相对距离。具体做法对 Query 和 Key 向量施加一个旋转旋转角度正比于 token 的位置编号。当计算内积⟨qm,kn⟩\langle q_m, k_n \rangle⟨qm​,kn​⟩时结果只依赖于相对位置(m−n)(m - n)(m−n)。q_m · k_n f(内容相似度, m - n) ← 只和相对距离有关 ↑ 不是绝对位置 m 或 n这带来了一个重要能力训练时用 4K 序列推理时可以外推到 16K 甚至 128K配合后续的 RoPE Scaling 技术。2. 数学原理借用复数的旋转2.1 二维旋转的数学把一个二维向量(x1,x2)(x_1, x_2)(x1​,x2​)旋转角度θ\thetaθ[x1′x2′][cos⁡θ−sin⁡θsin⁡θcos⁡θ][x1x2]\begin{bmatrix} x_1 \\ x_2 \end{bmatrix} \begin{bmatrix} \cos\theta -\sin\theta \\ \sin\theta \cos\theta \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}[x1′​x2′​​][cosθsinθ​−sinθcosθ​][x1​x2​​]用复数可以写得更优雅zx1ix2z x_1 i x_2zx1​ix2​旋转θ\thetaθ就是z′z⋅eiθz z \cdot e^{i\theta}z′z⋅eiθ。其中eiθcos⁡θisin⁡θe^{i\theta} \cos\theta i\sin\thetaeiθcosθisinθ欧拉公式。2.2 推广到高维分组旋转实际的 Query/Key 向量是ddd维的比如d128d128d128。RoPE 的做法是把ddd维向量两两配对分成d/2d/2d/2组每组2 维做一个独立的旋转第iii组的旋转角度是θi10000−2i/d\theta_i 10000^{-2i/d}θi​10000−2i/d对于位置mmm的 token第iii组旋转m⋅θim \cdot \theta_im⋅θi​为什么不同维度组旋转速度不同θi\theta_iθi​随iii指数衰减——低维度组旋转快高频高维度组旋转慢低频。这模拟了短距离信息靠高频捕捉长距离信息靠低频捕捉的信号处理直觉。2.3 频率预计算# 频率公式θ_i 10000^{-2i/d}i 0, 1, ..., d/2-1freqs1.0/(10000**(torch.arange(0,dim,2)[:dim//2].float()/dim))# 结果freqs [1/1.00, 1/1.15, 1/1.33, ..., 1/10000]# 高频 ←――――――――――――――――→ 低频对于每个位置mmm0 到 seq_len-1该位置的旋转角度矩阵是m * freqs。用极坐标生成复数# torch.polar(abs, angle) 生成 abs * e^{i * angle}freqs_cistorch.polar(torch.ones_like(angles),angles)# shape: [seq_len, dim//2]freqs_cis[m, i]就是位置 m 在第 i 组维度的旋转复数ei⋅m⋅θie^{i \cdot m \cdot \theta_i}ei⋅m⋅θi​。2.4 应用旋转复数乘法把 Query 的最后一维dddreshape 成[d/2,2][d/2, 2][d/2,2]然后用torch.view_as_complex解释为d/2d/2d/2个复数# xq shape: [B, L, num_heads, d]# Step 1: reshape → [B, L, num_heads, d/2, 2]# Step 2: view_as_complex → [B, L, num_heads, d/2] (复数张量)xq_complextorch.view_as_complex(xq.reshape(*xq.shape[:-1],-1,2))然后广播复数乘法xq_rotated_complexxq_complex*freqs_cis# 复数旋转# 再转回实数xq_rotatedtorch.view_as_real(xq_rotated_complex).flatten(3)复数乘法自动实现了旋转矩阵(abi)×(cos⁡θisin⁡θ)(acos⁡θ−bsin⁡θ)i(asin⁡θbcos⁡θ)(abi) \times (\cos\theta i\sin\theta) (a\cos\theta - b\sin\theta) i(a\sin\theta b\cos\theta)(abi)×(cosθisinθ)(acosθ−bsinθ)i(asinθbcosθ)这正好等于[cos⁡θ−sin⁡θsin⁡θcos⁡θ][ab]\begin{bmatrix} \cos\theta -\sin\theta \\ \sin\theta \cos\theta \end{bmatrix} \begin{bmatrix} a \\ b \end{bmatrix}[cosθsinθ​−sinθcosθ​][ab​]。3. 代码实现3.1 预计算频率表defprecompute_freqs_cis(dim:int,end:int,theta:float10000.0): 计算复数指数频率张量。 返回 shape: [end, dim//2]dtypecomplex64 freqs_cis[m, i] e^{i * m * θ_i} # Step 1: 计算每个维度组的频率 θ_i 10000^{-2i/d}freqs1.0/(theta**(torch.arange(0,dim,2)[:(dim//2)].float()/dim))# Step 2: 每个位置 m 对应的角度 m * θ_ittorch.arange(end,dtypetorch.float32)# [end]anglestorch.outer(t,freqs)# [end, dim//2]# Step 3: 用极坐标生成复数 e^{i * angle}freqs_cistorch.polar(torch.ones_like(angles),angles)returnfreqs_cistorch.outer的作用把t[0..seq_len]和freqs[0..d/2]做外积得到[seq_len, d/2]的角度矩阵。每个元素angles[m, i] m * θ_i。3.2 应用旋转编码 —— 关键必须 FP32 Upcastdefapply_rotary_emb(xq,xk,freqs_cis): 对 Query 和 Key 施加 RoPE 旋转。 xq, xk: [B, L, num_heads, head_dim] freqs_cis: [seq_len, head_dim//2] # Step 1: 转为复数 — ⚠️ 先升精度到 FP32# 复数乘法在 FP16 下极易产生 NaNLLaMA 源码强制 FP32xq_torch.view_as_complex(xq.float().reshape(*xq.shape[:-1],-1,2))xk_torch.view_as_complex(xk.float().reshape(*xk.shape[:-1],-1,2))# Step 2: 调整 freqs_cis 形状以广播# freqs_cis: [L, d/2] → [1, L, 1, d/2]freqs_cisreshape_for_broadcast(freqs_cis,xq_)# Step 3: 复数乘法旋转并转回实数xq_outtorch.view_as_real(xq_*freqs_cis).flatten(3)xk_outtorch.view_as_real(xk_*freqs_cis).flatten(3)# Step 4: 转回输入精度returnxq_out.type_as(xq),xk_out.type_as(xk)3.3 为什么必须 Upcast 到 FP32这是 RoPE 实现中最容易踩的坑。复数乘法(abi)×(cdi)(abi) \times (cdi)(abi)×(cdi)内部做了 4 次浮点乘法和 2 次加减。在 FP16 下精度只有 ~3.3 位有效十进制数字多次乘法后误差快速累积复数运算的误差传播更复杂实部虚部交叉影响极易产生 NaNLLaMA 官方源码强制在 FP32 下做 RoPE 旋转算完再转回 FP16/BF16。如果你漏了.float()模型在训练几万步后精度会悄悄退化。4. 工业实现对照4.1 LLaMA 源码的关键差异维度本教程复数法LLaMA 官方实数法实现方式view_as_complex 复数乘法手动 split cos/sin 交叉乘加可读性⭐⭐⭐⭐⭐ 数学直觉清晰⭐⭐ 代码冗长性能依赖 PyTorch 复数优化编译器更容易优化精度必须FP32也需要 FP32LLaMA 官方用实数法是因为编译器的复数支持在旧版本不够好但原理完全等价。4.2 上下文外推Context Extension—— 训练 4K推理 128K模型在 4K 序列训练推理时如何支持 16K这就是 RoPE Scaling 要解决的问题方法做法代表线性插值位置索引除以缩放因子m → m/sLLaMA 2 (32K)NTK-aware调大基频θ 10000 → 100000让高频减速Qwen (128K)YaRN高频插值 低频外推按维度组分别处理学术方案核心思想都是压缩位置空间或降低旋转速度让训练时见过的旋转角度能覆盖推理时的更长上下文。5. 踩坑记录5.1 忘记 Upcast 到 FP32现象训练几万步后 loss 逐渐发散或精度退化FP16 下更严重根因复数乘法在 FP16 下误差累积。(abi)*(cdi)的每一步乘加都在损失精度解决在view_as_complex前加.float()return 前.type_as(xq)转回5.2 旋转角度生成顺序写反现象模型完全不收敛perplexity 巨高根因torch.outer(t, freqs)和torch.outer(freqs, t)生成的角度矩阵形状不同[seq_len, d/2]vs[d/2, seq_len]导致广播到错误的维度解决torch.outer(t, freqs)— t 是位置行freqs 是频率列5.3 只对 Query 旋转没对 Key 旋转现象位置信息似乎没生效长序列效果与非位置模型差不多根因RoPE 必须同时旋转 Q 和 K内积⟨q,k⟩\langle q, k \rangle⟨q,k⟩的交叉项才会自然出现相对位置m−nm-nm−n解决apply_rotary_emb必须同时处理 xq 和 xk5.4flatten(3)的参数记错现象输出形状变成[B, L, H, d/2, 2]而不是[B, L, H, d]根因view_as_real在最后增加了一个维度 2实部/虚部需要flatten(3)合并回[d]解决flatten(3)表示从第 3 维0-indexed开始压平6. 延伸思考为什么不对 Value 也做 RoPE实验发现对 V 做旋转没有额外收益。因为 Attention 的位置感知只需要⟨qm,kn⟩\langle q_m, k_n \rangle⟨qm​,kn​⟩体现相对位置——输出是 V 的加权和V 本身不需要知道位置RoPE 和 ALiBi 的区别ALiBi 直接在 Attention score 上加一个相对位置 bias不修改 Q/K 本身更简单但表达力弱于 RoPETriton 融合实现在 Triton 中可以把整个 RoPE kernel 写成一个 fused kernel消除中间张量的显存读写与后续内容的关系RoPE 是 Attention 实现04_Attention_MHA_GQA的前置知识。理解 RoPE 后MHA/GQA 中 Q/K 的初始化和前向传播就顺理成章了RoPE 的优雅之处在于用复数旋转这个 200 年前的数学工具解决了大模型位置泛化这个 2023 年的工程难题。复数不是装饰品是实实在在的工程抓手。