LlamaAttention源码解析
- 1. LlamaAttention 介绍
- 1.1 多头注意力机制
- 1.2 注意力的计算过程
- 2. LlamaAttention类 源码解析
- 3. LlamaAttention类 的优化
- 3.1 LlamaFlashAttention2
- 3.2 LlamaSdpaAttention
1. LlamaAttention 介绍
LlamaAttention 是 LLaMA 模型中负责实现自注意力机制的核心组件,其使用了多头自注意力(Multi-Head Self-Attention)机制,允许模型在不同的子空间中并行计算注意力,从而提高了对信息的表达能力。
1.1 多头注意力机制
多头注意力是一种在现代神经网络中广泛使用的机制,特别是在Transformer架构中。其结构如下:
图片参考来源:Attention Is All You Need
1.2 注意力的计算过程
- 计算每个头的Q、K、V:
Q i = X W i Q , K i = X W i K , V i = X W i V Q_i=XW_i^{Q}, K_i=XW_i^{K}, V_i=XW_i^{V} Qi=XWiQ,Ki=XWiK,Vi=XWiV - 计算每个头的注意力得分:
s c o r e s i = Q i K i T d k scores_i=\frac{Q_iK_i^{T}}{\sqrt{d_k} } scoresi=dkQiKiT
使用掩码mask(可选):
s c o r e s i = s c o r e s i + m a s k scores_i =scores_i +mask scoresi=scoresi+mask - 计算每个头的注意力权重并softmax归一化:
a t t e n t i o n _ w e i g h t s i = s o f t m a x ( s c o r e s i ) attention\_weights_i=softmax(scores_i) attention_weightsi=softmax(scoresi) - 计算每个头的加权和:
a t t n _ o u t p u t i = a t t e n t i o n _ w e i g h t s i V i attn\_output_i=attention\_weights_iV_i attn_outputi=attention_weightsiVi - 拼接所有注意力头并进行线性变换:
a t t n _ o u t p u t = c o n c a t ( a t t n _ o u t p u t 1 , a t t n _ o u t p u t 2 , . . . , a t t n _ o u t p u t h ) attn\_output=concat(attn\_output_1, attn\_output_2,...,attn\_output_h) attn_output=concat(attn_output1,attn_output2,...,attn_outputh)
f i n a l _ o u t p u t = a t t n _ o u t p u t W O final\_output=attn\_outputW^O final_output=attn_outputWO
2. LlamaAttention类 源码解析
源码地址:transformers/src/transformers/models/llama/modeling_llama.py
# -*- coding: utf-8 -*-
# @time: 2024/8/28 15:15
import math
import torch
import torch.nn.functional as Ffrom typing import Optional, Tuple
from torch import nn
from transformers import LlamaConfig, Cache
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv
from transformers.utils import logginglogger = logging.get_logger(__name__)class LlamaAttention(nn.Module):"""Multi-headed attention from 'Attention Is All You Need' paper"""def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):super().__init__()self.config = config # 获取配置对象self.layer_idx = layer_idx # 获取当前层的索引# 如果没有提供 layer_idx,会警告用户这可能在使用缓存时导致错误if layer_idx is None:logger.warning_once(f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will ""lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` ""when creating this class.")self.attention_dropout = config.attention_dropout # 从配置中获取注意力dropout的概率self.hidden_size = config.hidden_size # 获取隐藏层的维度大小self.num_heads = config.num_attention_heads # 获取注意力头的数量self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) # 获取每个注意力头的维度,如果没有指定,则默认等于隐藏层维度除以注意力头的数量self.num_key_value_heads = config.num_key_value_heads # 键和值的头的数量self.num_key_value_groups = self.num_heads // self.num_key_value_heads # 每组键/值头的数量self.max_position_embeddings = config.max_position_embeddings # 最大位置嵌入数量self.rope_theta = config.rope_theta # 旋转位置嵌入的角度参数self.is_causal = True # 标志这个注意力机制是因果的,即只考虑当前位置及其之前的位置# 定义线性投影层self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)# TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)# 在 v4.45 版本中移除(RoPE 在模型中计算,而不是在解码器层中)# 定义旋转位置嵌入(RoPE)层self.rotary_emb = LlamaRotaryEmbedding(config=self.config)def forward(self,hidden_states: torch.Tensor, # 输入的隐藏状态attention_mask: Optional[torch.Tensor] = None, # 注意力掩码(可选)position_ids: Optional[torch.LongTensor] = None, # 位置id(可选)past_key_value: Optional[Cache] = None, # 缓存键和值(可选)output_attentions: bool = False, # 是否输出注意力权重use_cache: bool = False, # 是否使用缓存cache_position: Optional[torch.LongTensor] = None, # 缓存位置(可选)position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # 位置嵌入,将在v4.45中作为必选项**kwargs,) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:bsz, q_len, _ = hidden_states.size() # 获取批次大小和序列长度# --------------------------------1. Q K V的线性计算(多处理器和单处理器)-------------------------------------## 如果配置中启用了多处理器训练if self.config.pretraining_tp > 1:key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp # 计算每个处理器将处理的键值头和维度的切片大小# 将查询投影权重矩阵(self.q_proj.weight)按行拆分成多个切片,以便分配到不同的处理器上。每个切片的大小为 self.num_heads * self.head_dim 除以处理器数量query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) # 将键投影权重矩阵(self.k_proj.weight)按行拆分成多个切片,每个切片的大小为之前计算的 key_value_slicing。value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) # 将值投影权重矩阵(self.v_proj.weight)按行拆分成多个切片。每个切片的大小也为 key_value_slicing。# 多处理器环境下的注意力计算# 对每个处理器上的查询切片应用线性变换,将所有处理器的查询输出拼接在一起,形成一个完整的查询张量query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]query_states = torch.cat(query_states, dim=-1)# 对每个处理器上的键切片应用线性变换,将所有处理器的键输出拼接在一起,形成一个完整的键张量key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]key_states = torch.cat(key_states, dim=-1)# 对每个处理器上的值切片应用线性变换,将所有处理器的值输出拼接在一起,形成一个完整的值张量value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]value_states = torch.cat(value_states, dim=-1)else: # 单处理器,正常的注意力计算query_states = self.q_proj(hidden_states) # 计算querykey_states = self.k_proj(hidden_states) # 计算keyvalue_states = self.v_proj(hidden_states) # 计算value# --------------------------------2. 调整Q K V的size, 适应多头注意力的维度格式-------------------------------------## 调整query、key和value的形状,使它们符合多头注意力的格式# 具体维度变化为:[bsz, q_len, num_heads * head_dim] -> [bsz, q_len, num_heads, head_dim] -> [bsz, num_heads, q_len, head_dim]query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)# --------------------------------3. 为Q K添加位置编码-------------------------------------## 如果没有提供位置嵌入,计算旋转位置嵌入;否则,直接使用位置嵌入if position_embeddings is None:logger.warning_once("The attention layers in this model are transitioning from computing the RoPE embeddings internally ""through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed ""`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be ""removed and `position_embeddings` will be mandatory.")cos, sin = self.rotary_emb(value_states, position_ids)else:cos, sin = position_embeddings# apply_rotary_pos_emb 函数通过旋转位置编码对查询和键张量进行增强query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)# 如果提供了缓存的键和值,更新缓存中的键和值if past_key_value is not None:# sin and cos are specific to RoPE models; cache_position needed for the static cachecache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)# --------------------------------4. 再次调整Q K 的size, 适应注意力头的数量-------------------------------------## 调整键(key)和值(value)张量的维度,以适应模型的注意力头数量。key_states = repeat_kv(key_states, self.num_key_value_groups)value_states = repeat_kv(value_states, self.num_key_value_groups)# --------------------------------5. 自注意力的计算-------------------------------------## 5.1 计算查询和键的点积,并进行缩放attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)# 如果提供了注意力掩码,将因果掩码加到注意力权重上if attention_mask is not None: # no matter the length, we just slice itcausal_mask = attention_mask[:, :, :, : key_states.shape[-2]]attn_weights = attn_weights + causal_mask# upcast attention to fp32# 5.2 计算softmax归一化后的注意力权重attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)# 5.3 将注意力权重与值张量进行矩阵乘法,以生成注意力输出attn_output = torch.matmul(attn_weights, value_states)# --------------------------------6. 自注意力size的调整和结果输出-------------------------------------## 检查输出的size是否正确if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"f" {attn_output.size()}")# 调整注意力输出的sizeattn_output = attn_output.transpose(1, 2).contiguous()# 将多头输出拼接回原始维度attn_output = attn_output.reshape(bsz, q_len, -1)# 如果是多处理器训练if self.config.pretraining_tp > 1:attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) # 将 attn_output 张量沿着 dim=2 (特征维度) 拆分成多个片段,以适应分片训练的设置o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) # 将输出投影权重 (o_proj.weight) 拆分成多个片段,以与 attn_output 对应attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) # 对拆分的 attn_output 和权重片段进行线性变换并汇总else:attn_output = self.o_proj(attn_output) # 通过线性投影层得到最终输出# 如果不需要输出注意力权重,将注意力权重设置为 Noneif not output_attentions:attn_weights = None# 返回最终的注意力输出、注意力权重和缓存的键值return attn_output, attn_weights, past_key_value
3. LlamaAttention类 的优化
3.1 LlamaFlashAttention2
LlamaFlashAttention2 是对 LlamaAttention 的一种优化实现,主要用于提高计算效率。它继承了 LlamaAttention 的所有权重和结构,但在前向传播过程中调用了 Flash Attention 的实现。Flash Attention 是一种高效的注意力计算方法,通过使用特定的优化技术(如顶部左对齐掩码或滑动窗口),能显著减少内存占用和计算时间。以下是 LlamaFlashAttention2 的几个关键特点:
- 高效计算:利用 Flash Attention 提升计算效率,特别是在处理长序列时。
- 动态掩码:支持变长序列和填充标记的处理,通过对填充标记进行适当的处理来提高精度。
- 适配性:根据 Flash Attention 的版本调整参数,比如是否使用顶部左对齐掩码(
use_top_left_mask
)。
代码片段:
源码地址:transformers/src/transformers/models/llama/modeling_llama.py
attn_output = _flash_attention_forward(query_states,key_states,value_states,attention_mask,q_len,position_ids=position_ids,dropout=dropout_rate,sliding_window=getattr(self, "sliding_window", None),use_top_left_mask=self._flash_attn_uses_top_left_mask,is_causal=self.is_causal,
)
3.2 LlamaSdpaAttention
LlamaSdpaAttention 是使用 torch.nn.functional.scaled_dot_product_attention
实现的注意力机制。它继承了 LlamaAttention,但在前向传播过程中适配了 SDPA(Static Dynamic Position-Aware Attention) API。以下是 LlamaSdpaAttention 的几个关键特点:
- 使用 SDPA:通过
torch.nn.functional.scaled_dot_product_attention
实现,高效计算注意力分数。 - 兼容性:适配了 SDPA API 的变化和特性,如处理填充标记和位置编码的不同方式。
- 优化注意力计算:在处理包含填充标记的序列时,通过预处理和掩码调整来提高效率。
代码片段:
源码地址:transformers/src/transformers/models/llama/modeling_llama.py
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states,key_states,value_states,attn_mask=causal_mask,dropout_p=self.attention_dropout if self.training else 0.0,is_causal=is_causal,
)
总的来说,LlamaFlashAttention2通过引入 Flash Attention 的优化技术,提高了长序列的计算效率和处理能力,特别适合需要高效处理大规模数据的场景。LlamaSdpaAttention结合 SDPA API 提供了高效的注意力计算,支持在处理填充标记和不同位置编码时的优化,适用于需要精确和高效位置感知的任务。