当前位置: 首页> 教育> 培训 > 【HuggingFace Transformers】LlamaAttention源码解析

【HuggingFace Transformers】LlamaAttention源码解析

时间:2025/8/25 17:20:40来源:https://blog.csdn.net/weixin_47936614/article/details/141820116 浏览次数:1次

LlamaAttention源码解析

  • 1. LlamaAttention 介绍
    • 1.1 多头注意力机制
    • 1.2 注意力的计算过程
  • 2. LlamaAttention类 源码解析
  • 3. LlamaAttention类 的优化
    • 3.1 LlamaFlashAttention2
    • 3.2 LlamaSdpaAttention

1. LlamaAttention 介绍

LlamaAttentionLLaMA 模型中负责实现自注意力机制的核心组件,其使用了多头自注意力(Multi-Head Self-Attention)机制,允许模型在不同的子空间中并行计算注意力,从而提高了对信息的表达能力。

1.1 多头注意力机制

多头注意力是一种在现代神经网络中广泛使用的机制,特别是在Transformer架构中。其结构如下:
在这里插入图片描述
图片参考来源:Attention Is All You Need

1.2 注意力的计算过程

  1. 计算每个头的Q、K、V:
    Q i = X W i Q , K i = X W i K , V i = X W i V Q_i=XW_i^{Q}, K_i=XW_i^{K}, V_i=XW_i^{V} Qi=XWiQ,Ki=XWiK,Vi=XWiV
  2. 计算每个头的注意力得分:
    s c o r e s i = Q i K i T d k scores_i=\frac{Q_iK_i^{T}}{\sqrt{d_k} } scoresi=dk QiKiT
    使用掩码mask(可选):
    s c o r e s i = s c o r e s i + m a s k scores_i =scores_i +mask scoresi=scoresi+mask
  3. 计算每个头的注意力权重并softmax归一化:
    a t t e n t i o n _ w e i g h t s i = s o f t m a x ( s c o r e s i ) attention\_weights_i=softmax(scores_i) attention_weightsi=softmax(scoresi)
  4. 计算每个头的加权和:
    a t t n _ o u t p u t i = a t t e n t i o n _ w e i g h t s i V i attn\_output_i=attention\_weights_iV_i attn_outputi=attention_weightsiVi
  5. 拼接所有注意力头并进行线性变换:
    a t t n _ o u t p u t = c o n c a t ( a t t n _ o u t p u t 1 , a t t n _ o u t p u t 2 , . . . , a t t n _ o u t p u t h ) attn\_output=concat(attn\_output_1, attn\_output_2,...,attn\_output_h) attn_output=concat(attn_output1,attn_output2,...,attn_outputh)
    f i n a l _ o u t p u t = a t t n _ o u t p u t W O final\_output=attn\_outputW^O final_output=attn_outputWO

2. LlamaAttention类 源码解析

源码地址:transformers/src/transformers/models/llama/modeling_llama.py

# -*- coding: utf-8 -*-
# @time: 2024/8/28 15:15
import math
import torch
import torch.nn.functional as Ffrom typing import Optional, Tuple
from torch import nn
from transformers import LlamaConfig, Cache
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
from transformers.utils import logginglogger = logging.get_logger(__name__)class LlamaAttention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper"""def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):super().__init__()self.config = config  # 获取配置对象self.layer_idx = layer_idx  # 获取当前层的索引# 如果没有提供 layer_idx,会警告用户这可能在使用缓存时导致错误if layer_idx is None:logger.warning_once(f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will ""lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` ""when creating this class.")self.attention_dropout = config.attention_dropout  # 从配置中获取注意力dropout的概率self.hidden_size = config.hidden_size  # 获取隐藏层的维度大小self.num_heads = config.num_attention_heads  # 获取注意力头的数量self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)  # 获取每个注意力头的维度,如果没有指定,则默认等于隐藏层维度除以注意力头的数量self.num_key_value_heads = config.num_key_value_heads  # 键和值的头的数量self.num_key_value_groups = self.num_heads // self.num_key_value_heads  # 每组键/值头的数量self.max_position_embeddings = config.max_position_embeddings  # 最大位置嵌入数量self.rope_theta = config.rope_theta  # 旋转位置嵌入的角度参数self.is_causal = True  # 标志这个注意力机制是因果的,即只考虑当前位置及其之前的位置# 定义线性投影层self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)# 在 v4.45 版本中移除(RoPE 在模型中计算,而不是在解码器层中)# 定义旋转位置嵌入(RoPE)层self.rotary_emb = LlamaRotaryEmbedding(config=self.config)def forward(self,hidden_states: torch.Tensor,  # 输入的隐藏状态attention_mask: Optional[torch.Tensor] = None,  # 注意力掩码(可选)position_ids: Optional[torch.LongTensor] = None,  # 位置id(可选)past_key_value: Optional[Cache] = None,  # 缓存键和值(可选)output_attentions: bool = False,  # 是否输出注意力权重use_cache: bool = False,  # 是否使用缓存cache_position: Optional[torch.LongTensor] = None,  # 缓存位置(可选)position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # 位置嵌入,将在v4.45中作为必选项**kwargs,) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:bsz, q_len, _ = hidden_states.size()  # 获取批次大小和序列长度# --------------------------------1. Q K V的线性计算(多处理器和单处理器)-------------------------------------## 如果配置中启用了多处理器训练if self.config.pretraining_tp > 1:key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp  # 计算每个处理器将处理的键值头和维度的切片大小# 将查询投影权重矩阵(self.q_proj.weight)按行拆分成多个切片,以便分配到不同的处理器上。每个切片的大小为 self.num_heads * self.head_dim 除以处理器数量query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)  # 将键投影权重矩阵(self.k_proj.weight)按行拆分成多个切片,每个切片的大小为之前计算的 key_value_slicing。value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)  # 将值投影权重矩阵(self.v_proj.weight)按行拆分成多个切片。每个切片的大小也为 key_value_slicing。# 多处理器环境下的注意力计算# 对每个处理器上的查询切片应用线性变换,将所有处理器的查询输出拼接在一起,形成一个完整的查询张量query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]query_states = torch.cat(query_states, dim=-1)# 对每个处理器上的键切片应用线性变换,将所有处理器的键输出拼接在一起,形成一个完整的键张量key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]key_states = torch.cat(key_states, dim=-1)# 对每个处理器上的值切片应用线性变换,将所有处理器的值输出拼接在一起,形成一个完整的值张量value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]value_states = torch.cat(value_states, dim=-1)else:  # 单处理器,正常的注意力计算query_states = self.q_proj(hidden_states)  # 计算querykey_states = self.k_proj(hidden_states)  # 计算keyvalue_states = self.v_proj(hidden_states)  # 计算value# --------------------------------2. 调整Q K V的size, 适应多头注意力的维度格式-------------------------------------## 调整query、key和value的形状,使它们符合多头注意力的格式# 具体维度变化为:[bsz, q_len, num_heads * head_dim] -> [bsz, q_len, num_heads, head_dim] -> [bsz, num_heads, q_len, head_dim]query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)# --------------------------------3. 为Q K添加位置编码-------------------------------------## 如果没有提供位置嵌入,计算旋转位置嵌入;否则,直接使用位置嵌入if position_embeddings is None:logger.warning_once("The attention layers in this model are transitioning from computing the RoPE embeddings internally ""through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ""`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be ""removed and `position_embeddings` will be mandatory.")cos, sin = self.rotary_emb(value_states, position_ids)else:cos, sin = position_embeddings# apply_rotary_pos_emb 函数通过旋转位置编码对查询和键张量进行增强query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)# 如果提供了缓存的键和值,更新缓存中的键和值if past_key_value is not None:# sin and cos are specific to RoPE models; cache_position needed for the static cachecache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)# --------------------------------4. 再次调整Q K 的size, 适应注意力头的数量-------------------------------------## 调整键(key)和值(value)张量的维度,以适应模型的注意力头数量。key_states = repeat_kv(key_states, self.num_key_value_groups)value_states = repeat_kv(value_states, self.num_key_value_groups)# --------------------------------5. 自注意力的计算-------------------------------------## 5.1 计算查询和键的点积,并进行缩放attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)# 如果提供了注意力掩码,将因果掩码加到注意力权重上if attention_mask is not None:  # no matter the length, we just slice itcausal_mask = attention_mask[:, :, :, : key_states.shape[-2]]attn_weights = attn_weights + causal_mask# upcast attention to fp32# 5.2 计算softmax归一化后的注意力权重attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# 5.3 将注意力权重与值张量进行矩阵乘法,以生成注意力输出attn_output = torch.matmul(attn_weights, value_states)# --------------------------------6. 自注意力size的调整和结果输出-------------------------------------## 检查输出的size是否正确if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"f" {attn_output.size()}")# 调整注意力输出的sizeattn_output = attn_output.transpose(1, 2).contiguous()# 将多头输出拼接回原始维度attn_output = attn_output.reshape(bsz, q_len, -1)# 如果是多处理器训练if self.config.pretraining_tp > 1:attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)  # 将 attn_output 张量沿着 dim=2 (特征维度) 拆分成多个片段,以适应分片训练的设置o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)  # 将输出投影权重 (o_proj.weight) 拆分成多个片段,以与 attn_output 对应attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])  # 对拆分的 attn_output 和权重片段进行线性变换并汇总else:attn_output = self.o_proj(attn_output)  # 通过线性投影层得到最终输出# 如果不需要输出注意力权重,将注意力权重设置为 Noneif not output_attentions:attn_weights = None# 返回最终的注意力输出、注意力权重和缓存的键值return attn_output, attn_weights, past_key_value

3. LlamaAttention类 的优化

3.1 LlamaFlashAttention2

LlamaFlashAttention2 是对 LlamaAttention 的一种优化实现,主要用于提高计算效率。它继承了 LlamaAttention 的所有权重和结构,但在前向传播过程中调用了 Flash Attention 的实现。Flash Attention 是一种高效的注意力计算方法,通过使用特定的优化技术(如顶部左对齐掩码或滑动窗口),能显著减少内存占用和计算时间。以下是 LlamaFlashAttention2 的几个关键特点:

  • 高效计算:利用 Flash Attention 提升计算效率,特别是在处理长序列时。
  • 动态掩码:支持变长序列和填充标记的处理,通过对填充标记进行适当的处理来提高精度。
  • 适配性:根据 Flash Attention 的版本调整参数,比如是否使用顶部左对齐掩码(use_top_left_mask)。

代码片段
源码地址:transformers/src/transformers/models/llama/modeling_llama.py

attn_output = _flash_attention_forward(query_states,key_states,value_states,attention_mask,q_len,position_ids=position_ids,dropout=dropout_rate,sliding_window=getattr(self, "sliding_window", None),use_top_left_mask=self._flash_attn_uses_top_left_mask,is_causal=self.is_causal,
)

3.2 LlamaSdpaAttention

LlamaSdpaAttention 是使用 torch.nn.functional.scaled_dot_product_attention 实现的注意力机制。它继承了 LlamaAttention,但在前向传播过程中适配了 SDPA(Static Dynamic Position-Aware Attention) API。以下是 LlamaSdpaAttention 的几个关键特点:

  • 使用 SDPA:通过 torch.nn.functional.scaled_dot_product_attention 实现,高效计算注意力分数。
  • 兼容性:适配了 SDPA API 的变化和特性,如处理填充标记和位置编码的不同方式。
  • 优化注意力计算:在处理包含填充标记的序列时,通过预处理和掩码调整来提高效率。

代码片段
源码地址:transformers/src/transformers/models/llama/modeling_llama.py

attn_output = torch.nn.functional.scaled_dot_product_attention(query_states,key_states,value_states,attn_mask=causal_mask,dropout_p=self.attention_dropout if self.training else 0.0,is_causal=is_causal,
)

总的来说,LlamaFlashAttention2通过引入 Flash Attention 的优化技术,提高了长序列的计算效率和处理能力,特别适合需要高效处理大规模数据的场景。LlamaSdpaAttention结合 SDPA API 提供了高效的注意力计算,支持在处理填充标记和不同位置编码时的优化,适用于需要精确和高效位置感知的任务。

关键字:【HuggingFace Transformers】LlamaAttention源码解析

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: