nano-vllm 用千行代码拆解 vLLM 核心是读懂大模型推理最快的捷径。1. 介绍上一篇把Linear的 weight 按输入维、输出维切完了——列切、行切、合并的双重切。但模型里还有两处不是普通Linear切法也不一样几万行的词表embed_tokens、lm_head以及注意力的多个 head。词表沿「词」切注意力沿「头」切。本篇补齐这两种切法张量并行的三种切法weight、vocab、head就齐了。importtorchfromtorchimportnnimporttorch.nn.functionalasF2. 总览weight 切上一篇讲过。本篇按切法组织先 vocab 切——词表[vocab, dim]沿词行切每卡只存一段词再 head 切——注意力的 head 各自独立沿头切每卡只算几个头。Qwen3-0.6B 在 tp2 下词表 151936→75968、q 头 16→8、kv 头 8→4。3. VocabParallelEmbeddingvocab 切词表是模型里最大的表Qwen3151936 × 1024。按词切到各卡每卡只存vocab/tp个词。__init__词数除以卡数本卡负责[vocab_start, vocab_end)这一段词号权重建成[vocab/tp, dim]。weight_loader沿词维dim 0从磁盘整份切出本卡那段词。forward里本卡只认自己那段词的 token认不出的输出 0最后各卡求和拼回。像把每个词交给专管它的图书管理员你报一个词号只有管它的人取得到书别人交白卷0把各人手里的叠起来就是你要的那本。五步mask标出落在本卡词号区间内的 token。x mask * (x - vocab_start)本卡的 token 平移到局部索引[0, vocab/tp)越界的被乘成 0先占位到第 0 行。F.embedding查本卡这段词表。mask.unsqueeze(1) * y把越界 token 占位的那一行清零。all_reduce求和每个 token 只有一张卡查得到、其余卡是 0相加就拼回完整 embedding。第 4 步为什么要unsqueeze(1)这一步要把越界 token 那一整行embedding 清零而mask形状是[N]每个 token 一个 0/1y是[N, dim]每个 token 一行——直接相乘维度对不上。unsqueeze(1)在第 1 维插一根长度为 1 的轴把mask变成[N, 1]相乘时它沿dim方向广播第 i 行整行乘以mask[i]命中行×1原样保留越界行×0整行归零。为什么求和能拼回一个 token 落在哪段词就只有存那段的卡查得到值别的卡全是 0叠加时 0 不影响结果正好得到那个 token 的 embedding。classVocabParallelEmbedding(nn.Module):def__init__(self,num_embeddings,embedding_dim,tp_size,tp_rank):super().__init__()self.tp_sizetp_size# 真实代码dist.get_world_size()self.tp_ranktp_rank# 真实代码dist.get_rank()pernum_embeddings//tp_size# 每卡词数self.vocab_start_idxper*tp_rank# 本卡词号区间 [start, end)self.vocab_end_idxself.vocab_start_idxper self.weightnn.Parameter(torch.empty(per,embedding_dim))self.weight.weight_loaderself.weight_loaderdefweight_loader(self,param,loaded_weight):# 沿词维(dim0)切出本卡那段词同上一篇列切shardparam.size(0)startself.tp_rank*shard param.data.copy_(loaded_weight.narrow(0,start,shard))defforward(self,x):mask(xself.vocab_start_idx)(xself.vocab_end_idx)# 只计算落到本卡的tokenxmask*(x-self.vocab_start_idx)# 本卡token→局部索引,越界→0yF.embedding(x,self.weight)ymask.unsqueeze(1)*y# 越界token那行清零# dist.all_reduce(y) # 真实代码:各卡求和; 单进程见第6章手动 y0y1returny# 合成词表[4,2]每行填可辨认常数两卡各 load 本卡词段fulltorch.arange(1,9,dtypetorch.float).reshape(4,2)e0VocabParallelEmbedding(4,2,tp_size2,tp_rank0)e1VocabParallelEmbedding(4,2,tp_size2,tp_rank1)e0.weight_loader(e0.weight,full)e1.weight_loader(e1.weight,full)print(rank0 存词,e0.vocab_start_idx,e0.vocab_end_idx)# [0, 2]toktorch.tensor([1,3])print(rank0 partial\n,e0(tok))# 只 token1 非零print(rank1 partial\n,e1(tok))# 只 token3 非零4. ParallelLMHeadvocab 切lm_head把每个位置的向量投影成每个词的 logits输出维就是词表大小。它继承VocabParallelEmbedding权重同样是[vocab/tp, dim]的本卡词段。forward本卡用自己那段词的权重算出本卡负责的vocab/tp个词的 logitsF.linear再gather到 rank0、沿词维cat拼成完整vocab维 logits。prefill 只取每条序列最后一位算 logits。同是 vocab 切为什么embed用all_reduce求和、lm_head用gather拼接embed切在输入侧token 索引落在词维。每个 token 只有一张卡查得到各卡输出形状相同、位置互补——求和拼回。lm_head切在输出侧logits 本身在词维。每卡算的是不同的词段rank0 算前一半词、rank1 算后一半各卡输出形状相同、内容不同——拼接才完整。classParallelLMHead(VocabParallelEmbedding):defforward(self,x):# prefill 取每条最后一位logitsF.linear(x,self.weight)# 本卡词段 logits [*, vocab/tp]# dist.gather → cat # 真实代码:收到rank0拼接returnlogits Wtorch.arange(1,9,dtypetorch.float).reshape(4,2)# [vocab4, dim2]h0ParallelLMHead(4,2,tp_size2,tp_rank0)h1ParallelLMHead(4,2,tp_size2,tp_rank1)h0.weight_loader(h0.weight,W)h1.weight_loader(h1.weight,W)xtorch.randn(3,2)print(rank0 logits 段,tuple(h0(x).shape),→ 词 0,1)print(rank1 logits 段,tuple(h1(x).shape),→ 词 2,3)rank0 logits 段 (3, 2) → 词 0,1 rank1 logits 段 (3, 2) → 词 2,35. attention 与 KV cachehead 切注意力的每个 head 独立计算、互不交互所以可以按 head 切到各卡每卡只算自己那几个头算的过程零通信。Qwen3Attention构造时把头数按卡数整除num_heads total_num_heads // tp_size16→8、num_kv_heads total_num_kv_heads // tp_size8→4。qkv_proj用上一篇的QKVParallelLinear输出的就是本卡这几个头的 q/k/v算完注意力o_proj用RowParallelLinear末尾all_reduce把各卡的输出按隐层维相加。KV cache 跟着只存本卡的头。allocate_kv_cache里num_kv_heads num_key_value_heads // world_sizecache 形状的头维就是本卡头数——本卡只算本卡的头也只需存本卡头的历史 k/v显存随卡数减半。整条注意力里唯一的跨卡通信是o_proj的all_reducehead 切本身不通信。# attention 按 head 切Qwen3Attention.__init__ 的核心total_num_heads,total_num_kv_heads,tp16,8,2asserttotal_num_heads%tp0andtotal_num_kv_heads%tp0num_headstotal_num_heads//tp# 每卡 q 头num_kv_headstotal_num_kv_heads//tp# 每卡 kv 头print(每卡 q 头,num_heads, kv 头,num_kv_heads)# KV cache 跟着只存本卡的头allocate_kv_cache 的形状layers,blocks,block_size,head_dim28,100,256,128fortin(1,2):kvhtotal_num_kv_heads//t shape(2,layers,blocks,block_size,kvh,head_dim)print(ftp{t}kv 头/卡{kvh}kv_cache 形状{shape})每卡 q 头 8 kv 头 4 tp1 kv 头/卡8 kv_cache 形状(2, 28, 100, 256, 8, 128) tp2 kv 头/卡4 kv_cache 形状(2, 28, 100, 256, 4, 128)6. 集成验证单进程构造 rank0、rank1 两卡手动合并两卡算出来的数据对比单卡数据检查计算是否正确。# embed 两卡求和(模拟 all_reduce)tabletorch.randn(4,2)a0VocabParallelEmbedding(4,2,2,0);a0.weight_loader(a0.weight,table)a1VocabParallelEmbedding(4,2,2,1);a1.weight_loader(a1.weight,table)idstorch.tensor([0,1,2,3])allreducea0(ids)a1(ids)print(embed all_reduce 单卡:,torch.allclose(allreduce,F.embedding(ids,table)))# lm_head 两卡cat(模拟 gather)Wftorch.randn(4,2)b0ParallelLMHead(4,2,2,0);b0.weight_loader(b0.weight,Wf)b1ParallelLMHead(4,2,2,1);b1.weight_loader(b1.weight,Wf)xxtorch.randn(3,2)gathertorch.cat([b0(xx),b1(xx)],dim-1)print(lmhead gather 单卡:,torch.allclose(gather,F.linear(xx,Wf)))embed all_reduce 单卡: True lmhead gather 单卡: True7. 小结至此张量并行的三种切法已经介绍完毕weight 切上一篇Linear沿隐层维列切、行切行切末尾all_reduce。vocab 切本篇词表沿词维切。embed切输入侧、各卡互补数据all_reduce求和拼回lm_head切输出侧、各卡不同词段gather拼接拼回。head 切本篇注意力沿头维切每卡算自己的头、零通信KV cache 也只存本卡头唯一通信在o_proj的all_reduce。切分都讲完了但代码到现在还是单进程模拟——真实的多卡靠多进程每进程占一张卡tp_size/tp_rank从dist取。进程怎么起、卡间怎么传方法调用将在下一篇介绍。