LLM 推理加速:从算子融合到投机解码的工程实践

📅 2026/6/16 8:48:51
LLM 推理加速:从算子融合到投机解码的工程实践
LLM 推理加速从算子融合到投机解码的工程实践一、延迟瓶颈内存带宽而非算力大模型推理的延迟主要卡在四个环节数据搬运权重从 HBM 加载、计算矩阵乘和注意力、KV Cache 管理历史 Token 读写以及调度开销请求排队。实际部署中真正的瓶颈往往是内存带宽而非计算能力。以 A100-80G 为例其 FP16 峰值算力达 312 TFLOPS但 HBM 带宽仅为 2TB/s。一个 7B 模型单次前向传播计算量约 14GFLOP耗时 0.045ms但读取权重需 14GB耗时约 7ms。数据搬运耗时是计算的 155 倍。这就是典型的“内存墙”推理性能受限于带宽算力大部分时间在空转等待数据。加速的核心思路很直接减少内存访问算子融合、KV Cache 优化、提高计算密度连续批处理、投机解码、降低精度量化。具体选哪种得看业务对延迟和精度的容忍度。二、技术栈分层flowchart TB subgraph 模型层优化 Q[模型量化: FP16→INT8/INT4] -- Q1[显存减少50-75%] Q -- Q2[带宽需求降低] GQA[GQA/MQA: 共享KV头] -- GQA1[KV Cache减少4-8x] end subgraph 算子层优化 FUSE[算子融合: Flash Attention] -- FUSE1[减少HBM访问次数] FUSE -- FUSE2[单次前向: 7ms→2ms] KV[KV Cache分页: PagedAttention] -- KV1[显存利用率95%] end subgraph 调度层优化 CB[连续批处理: Continuous Batching] -- CB1[吞吐量提升2-3x] SD[投机解码: Speculative Decoding] -- SD1[延迟降低30-50%] PD[前缀缓存: Prefix Caching] -- PD1[重复Prompt零计算] end subgraph 系统层优化 CB1 -- THROUGHPUT[吞吐量优化] SD1 -- LATENCY[延迟优化] Q2 -- COST[成本优化] end style FUSE fill:#e3f2fd style CB fill:#fff3e0 style SD fill:#e8f5e9 style Q fill:#fce4ec优化通常按模型、算子、调度、系统四个层面展开。模型层解决显存和带宽量化、GQA算子层解决计算效率Flash Attention、PagedAttention调度层解决并发连续批处理、投机解码系统层解决资源复用前缀缓存。各层优化可独立生效组合使用效果更明显。三、核心工程实现3.1 连续批处理Continuous Batching传统静态批处理必须等所有请求生成完毕才能释放显存而连续批处理在每个迭代步动态调整批次完成的请求立即移出新请求立即加入。# continuous_batching.py — 连续批处理调度器 import time from dataclasses import dataclass, field from typing import Optional from collections import deque dataclass class InferenceRequest: 推理请求 request_id: str prompt_tokens: list[int] max_output_tokens: int 256 temperature: float 0.7 # 运行时状态 generated_tokens: list[int] field(default_factorylist) is_completed: bool False arrival_time: float field(default_factorytime.time) first_token_time: Optional[float] None class ContinuousBatcher: 连续批处理调度器 def __init__(self, max_batch_size: int 32, max_waiting_queue: int 1000, scheduling_policy: str fcfs): self._max_batch_size max_batch_size self._max_waiting_queue max_waiting_queue self._scheduling_policy scheduling_policy self._waiting_queue: deque[InferenceRequest] deque() self._running_batch: list[InferenceRequest] [] self._completed_requests: list[InferenceRequest] [] def submit(self, request: InferenceRequest) - bool: 提交推理请求 if len(self._waiting_queue) self._max_waiting_queue: return False self._waiting_queue.append(request) return True def step(self, model_step_fn) - list[InferenceRequest]: 执行一个推理步骤 # 1. 移除已完成的请求 completed [req for req in self._running_batch if req.is_completed] self._running_batch [req for req in self._running_batch if not req.is_completed] self._completed_requests.extend(completed) # 2. 补充新请求到批次 available_slots self._max_batch_size - len(self._running_batch) while available_slots 0 and self._waiting_queue: if self._scheduling_policy fcfs: request self._waiting_queue.popleft() elif self._scheduling_policy sjf: shortest min(self._waiting_queue, keylambda r: r.max_output_tokens) self._waiting_queue.remove(shortest) request shortest else: request self._waiting_queue.popleft() self._running_batch.append(request) available_slots - 1 # 3. 执行前向传播 if self._running_batch: model_step_fn(self._running_batch) for req in self._running_batch: if req.first_token_time is None: req.first_token_time time.time() if len(req.generated_tokens) req.max_output_tokens: req.is_completed True return completed def get_stats(self) - dict: 获取调度器统计信息 return { waiting_queue_size: len(self._waiting_queue), running_batch_size: len(self._running_batch), completed_count: len(self._completed_requests), utilization: round(len(self._running_batch) / self._max_batch_size, 2) if self._max_batch_size 0 else 0, }3.2 投机解码Speculative Decoding用小模型Draft Model快速生成 K 个候选 Token大模型Target Model一次性验证。只有被大模型接受的 Token 才计入最终结果。# speculative_decoding.py — 投机解码实现 import time from dataclasses import dataclass from typing import Optional dataclass class SpeculativeConfig: 投机解码配置 draft_model_name: str qwen2-0.5b target_model_name: str qwen2-7b speculative_length: int 5 temperature: float 0.7 class SpeculativeDecoder: 投机解码器 加速比 1 / (1 - 接受率) 当接受率为 80% 时理论加速比约 2.5x def __init__(self, draft_model_fnNone, target_model_fnNone, config: SpeculativeConfig None): self._draft_fn draft_model_fn self._target_fn target_model_fn self._config config or SpeculativeConfig() self._accept_stats { total_tokens: 0, accepted_tokens: 0, } def generate(self, prompt_tokens: list[int], max_tokens: int 256) - dict: 执行投机解码生成 generated [] total_draft_tokens 0 total_accepted 0 total_target_calls 0 while len(generated) max_tokens: # Step 1: 草稿模型快速生成 K 个候选 Token draft_tokens self._draft_generate( prompt_tokens generated, self._config.speculative_length, ) total_draft_tokens len(draft_tokens) # Step 2: 目标模型一次性验证 K1 个位置 verify_result self._target_verify( prompt_tokens generated, draft_tokens, ) total_target_calls 1 # Step 3: 处理验证结果 accepted_count verify_result[accepted_count] total_accepted accepted_count generated.extend(draft_tokens[:accepted_count]) # 从拒绝点采样或补充 bonus token if accepted_count len(draft_tokens): corrected_token verify_result.get(corrected_token) if corrected_token is not None: generated.append(corrected_token) else: bonus_token verify_result.get(bonus_token) if bonus_token is not None: generated.append(bonus_token) generated generated[:max_tokens] self._accept_stats[total_tokens] total_draft_tokens self._accept_stats[accepted_tokens] total_accepted accept_rate (total_accepted / total_draft_tokens if total_draft_tokens 0 else 0) return { generated_tokens: len(generated), total_draft_tokens: total_draft_tokens, accepted_tokens: total_accepted, accept_rate: round(accept_rate, 4), target_model_calls: total_target_calls, speedup_estimate: round(1 / (1 - accept_rate 0.1), 2), } def _draft_generate(self, context: list[int], num_tokens: int) - list[int]: 草稿模型生成候选 Token if self._draft_fn: return self._draft_fn(context, num_tokens) return list(range(100, 100 num_tokens)) def _target_verify(self, context: list[int], draft_tokens: list[int]) - dict: 目标模型验证候选 Token if self._target_fn: return self._target_fn(context, draft_tokens) import random accepted 0 for i in range(len(draft_tokens)): if random.random() 0.8: accepted 1 else: break return { accepted_count: accepted, corrected_token: 200 if accepted len(draft_tokens) else None, bonus_token: 300 if accepted len(draft_tokens) else None, } def get_accept_rate(self) - float: 获取历史平均接受率 total self._accept_stats[total_tokens] accepted self._accept_stats[accepted_tokens] return round(accepted / total, 4) if total 0 else 0四、精度代价与适用边界量化INT8 对 7B 模型精度影响通常在 0.5% 以内INT4 则在 1%-3%。对话生成等场景对 INT4 容忍度较高代码生成、数学推理等强逻辑任务建议保留 INT8 或 FP8。投机解码加速效果完全取决于草稿模型的接受率。如果接受率低于 60%验证开销会抵消生成收益反而变慢。同系列模型如 Qwen2-0.5B 配 Qwen2-7B输出分布接近接受率通常在 75%-85%效果最稳。连续批处理吞吐量上去了但尾部延迟可能增加。短请求若和长请求混批得等长请求跑完才能释放显存。解决办法是引入优先级调度或者按延迟要求分批次处理。前缀缓存缓存系统提示词等重复 Prompt 的 KV Cache 能省计算但会占显存。如果命中率低反而浪费资源。建议只缓存高频前缀并配上 LRU 淘汰策略。五、总结LLM 推理加速是全栈工程模型、算子、调度、系统四层都有优化空间。从投入产出比看Flash Attention 和连续批处理最值得优先落地。投机解码在“大小模型搭配”场景下效果明显但得先测接受率。量化是降低成本的直接手段INT8 风险低INT4 需评估业务容忍度。建议从 Flash Attention 连续批处理入手结合 pprof 数据决定是否引入投机解码和量化。每次优化后务必做基准测试用数据说话。