05. LLaMA3 Block Tutorial代码笔记

📅 2026/6/25 14:00:39
05. LLaMA3 Block Tutorial代码笔记
背景这个教程实现了LLaMA 风格的 Transformer 解码器层所需的两个核心模块LlamaMLP使用 SwiGLU 激活函数的前馈网络。LlamaDecoderLayer组合注意力与 MLP并采用Pre-Norm 残差结构归一化放在子层之前。同时上方提供了两个占位类DummyRMSNorm和DummyAttention模拟 RMSNorm 和多头注意力简化版便于整体流程演示。TODO 1定义 SwiGLU 所需的三个线性层无偏置self.gate_projnn.Linear(hidden_size,intermediate_size,biasFalse)self.up_projnn.Linear(hidden_size,intermediate_size,biasFalse)self.down_projnn.Linear(intermediate_size,hidden_size,biasFalse)SwiGLU 是什么SwiGLU 是一种激活函数结合了门控线性单元GLU和 Swish/SiLU 激活。其公式为SwiGLU(x)(SiLU(x⋅Wgate)) ⊙ (x⋅Wup) \text{SwiGLU}(x) \big( \text{SiLU}(x \cdot W_{\text{gate}}) \big) \;\odot\; \big( x \cdot W_{\text{up}} \big)SwiGLU(x)(SiLU(x⋅Wgate​))⊙(x⋅Wup​)然后再通过一个输出投影矩阵Wdown W_{\text{down}}Wdown​将维度映射回hidden_size。在 LLaMA 模型中这三个投影都是无偏置的。补全解释gate_proj将输入从hidden_size映射到intermediate_size然后通过 SiLU 激活充当“门控”信号0 到 1 之间的软开关。up_proj同样将输入从hidden_size映射到intermediate_size提供“数值”信号。down_proj将门控与数值的逐元素乘积结果从intermediate_size再映射回hidden_size恢复原始维度以便加上残差连接。三个线性层都没有偏置biasFalse严格遵循 LLaMA 设计。TODO 2实现 SwiGLU 的前向传播returnself.down_proj(F.silu(self.gate_proj(x))*self.up_proj(x))分步拆解self.gate_proj(x)计算门控部分的线性投影结果形状为[batch, seq, intermediate_size]。F.silu(...)对门控结果施加 SiLU也叫 Swish激活函数siLU(u)u⋅σ(u)\text{siLU}(u) u \cdot \sigma(u)siLU(u)u⋅σ(u)。这使门控值变为平滑的非线性范围约为(-0.278, ∞)在正值区间提供类似 ReLU 的激活在负值区间有轻微负值有利于梯度流动。self.up_proj(x)计算数值部分的线性投影形状同样为[batch, seq, intermediate_size]。*将激活后的门控信号与数值信号进行逐元素相乘。这就是 GLU 的核心门控控制了哪些信息可以通过。self.down_proj(...)将乘积结果投影回hidden_size输出形状恢复为[batch, seq, hidden_size]。这种结构比传统 ReLU/GeLU 的 MLP 多了一组门控分支能更好地筛选信息已被 LLaMA 等模型采用。TODO 3实现 LLaMA 的 Pre-Norm 残差连接# --- Attention Block ---residualhidden_states hself.input_layernorm(hidden_states)hself.self_attn(h)hresidualh# --- MLP Block ---residualh outself.post_attention_layernorm(h)outself.mlp(out)outresidualoutreturnoutLLaMA 的 Pre-Norm 结构与原始 TransformerPost-Norm在子层之后做 LayerNorm不同LLaMA 使用Pre-Norm即在注意力或 MLP 之前进行归一化。这样做的好处是训练更稳定梯度更容易流动是当前大语言模型的主流选择。代码详解注意力块Attention Blockresidual hidden_states保存残差连接的分支即输入本身。h self.input_layernorm(hidden_states)先对输入做 RMSNorm 归一化。h self.self_attn(h)归一化后的张量送入注意力模块这里用DummyAttention模拟实际会有 RoPE、GQA 等。h residual h将注意力输出与残差相加实现恒等映射的旁路。MLP 块MLP Blockresidual h保存注意力块输出作为新的残差。out self.post_attention_layernorm(h)再次进行 RMSNorm 归一化在 MLP 之前。out self.mlp(out)通过 SwiGLU MLP 层。out residual out加上残差输出最终结果。结构顺序对应典型的 LLaMA 解码器层输入 → RMSNorm → 注意力 → 残差 → RMSNorm → MLP → 残差 → 输出这种 Pre-Norm 残差的组合既保证了深层网络的训练稳定性又让每个子层都有一条原始的“高速通道”传递信息。总结TODO 1 2实现了 LLaMA 的 SwiGLU MLP用三个无偏置线性层前向时用 SiLU 门控乘以上投影再下投影。TODO 3实现了 Pre-Norm 的残差解码器层分别对注意力和 MLP 做“归一化→运算→加残差”完整串联起一层的计算流。这些组件组合起来就是现代大语言模型如 LLaMA最基本的构成单元。