写目录
- 代码
- 代码解释
- 示例
- netron 可视化
- F.linear操作
代码
import torch
from torch import nn
import torch.nn.functional as F
import mathclass MoEGate(nn.Module):def __init__(self, dim: int,n_routed_experts: int = 4,num_experts_per_tok: int = 2,scoring_func: str = 'softmax',aux_loss_alpha: float = 0.1,seq_aux: bool = True,norm_topk_prob: bool = True):super().__init__()self.top_k = num_experts_per_tokself.n_routed_experts = n_routed_expertsself.scoring_func = scoring_funcself.alpha = aux_loss_alphaself.seq_aux = seq_auxself.norm_topk_prob = norm_topk_probself.gating_dim = dimself.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))self.reset_parameters()def reset_parameters(self) -> None:import torch.nn.init as initinit.kaiming_uniform_(self.weight, a=math.sqrt(5))def forward(self, hidden_states):bsz, seq_len, h = hidden_states.shapehidden_states = hidden_states.view(-1, h)logits = F.linear(hidden_states, self.weight, None)if self.scoring_func == 'softmax':scores = logits.softmax(dim=-1)else:raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)if self.top_k > 1 and self.norm_topk_prob:denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20topk_weight = topk_weight / denominatorif self.training and self.alpha > 0.0:scores_for_aux = scoresaux_topk = self.top_ktopk_idx_for_aux_loss = topk_idx.view(bsz, -1)if self.seq_aux:scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)ce.scatter_add_(1, topk_idx_for_aux_loss,torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alphaelse:mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)ce = mask_ce.float().mean(0)Pi = scores_for_aux.mean(0)fi = ce * self.n_routed_expertsaux_loss = (Pi * fi).sum() * self.alphaelse:aux_loss = 0return topk_idx, topk_weight, aux_loss
代码解释
详细解释 MoEGate 类的运行逻辑:
- 初始化阶段:
def __init__(self, dim: int, n_routed_experts: int = 4, num_experts_per_tok: int = 2, ...):
dim
: 输入向量维度n_routed_experts
: 专家数量num_experts_per_tok
: 每个 token 选择的专家数量- 创建权重矩阵:
self.weight = nn.Parameter(torch.empty((n_routed_experts, dim)))
- 前向传播过程:
def forward(self, hidden_states):
a. 输入处理:
bsz, seq_len, h = hidden_states.shape
hidden_states = hidden_states.view(-1, h) # 展平为 [batch_size * seq_len, hidden_dim]
b. 计算专家分数:
logits = F.linear(hidden_states, self.weight, None) # [batch_size * seq_len, n_experts]
scores = logits.softmax(dim=-1) # 转换为概率分布
c. 选择 top-k 专家:
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1)
- 为每个 token 选择得分最高的 k 个专家
topk_idx
: 专家的索引topk_weight
: 对应的权重
d. 权重归一化(可选):
if self.top_k > 1 and self.norm_topk_prob:topk_weight = topk_weight / (topk_weight.sum(...) + 1e-20)
e. 辅助损失计算(训练时):
-
序列级别辅助损失(seq_aux=True):
- 计算每个序列中专家的使用分布
- 使用 scatter_add_ 累积专家使用次数
- 目标是使专家使用均匀
-
Token 级别辅助损失(seq_aux=False):
- 计算每个 token 的专家使用分布
- 使用 one-hot 编码记录专家选择
- 同样目标是平衡专家使用
- 返回结果:
return topk_idx, topk_weight, aux_loss
topk_idx
: 选中的专家索引 [batch_size * seq_len, num_experts_per_tok]topk_weight
: 对应的权重 [batch_size * seq_len, num_experts_per_tok]aux_loss
: 辅助损失(用于训练时平衡专家使用)
这个机制的核心目的是:
- 动态为每个 token 选择最合适的专家
- 确保专家负载均衡(通过辅助损失)
- 实现计算资源的高效利用
示例
# 创建 MoEGate 实例
gate = MoEGate(dim=512, # 输入维度n_routed_experts=4, # 专家数量num_experts_per_tok=2 # 每个token选择的专家数量
)# 创建示例输入
batch_size = 2
seq_len = 10
hidden_dim = 512
x = torch.randn(batch_size, seq_len, hidden_dim) # 形状: [2, 10, 512]# 前向传播
topk_idx, topk_weight, aux_loss = gate(x)print(f"输入 x 的形状: {x.shape}") # [2, 10, 512]
print(f"专家索引 topk_idx 的形状: {topk_idx.shape}") # [20, 2] (batch_size * seq_len, num_experts_per_tok)
print(f"专家权重 topk_weight 的形状: {topk_weight.shape}") # [20, 2]
print(f"辅助损失 aux_loss: {aux_loss}") # 标量值# 详细查看输出内容
print("\n专家索引示例(每个token选择的专家):")
print(topk_idx.reshape(batch_size, seq_len, -1)[0, 0]) # 第一个样本第一个token选择的专家
print("\n专家权重示例(对应的权重):")
print(topk_weight.reshape(batch_size, seq_len, -1)[0, 0]) # 对应的权重
输入 x 的形状: torch.Size([2, 10, 512])
专家索引 topk_idx 的形状: torch.Size([20, 2])
专家权重 topk_weight 的形状: torch.Size([20, 2])
辅助损失 aux_loss: 0.10424095392227173专家索引示例(每个token选择的专家):
tensor([2, 1])专家权重示例(对应的权重):
tensor([0.6414, 0.3586], grad_fn=<SelectBackward0>)
netron 可视化
因为 MoEGate 模型的输出包含多个返回值( topk_idx , topk_weight , aux_loss ),而 ONNX 导出时遇到了问题。主要原因是其中一些输出没有与输入建立可追踪的数据依赖关系。故仅返回return topk_idx
class MoEGate(nn.Module):# ... 保持其他代码不变 ...def forward(self, hidden_states):# ... 保持其他代码不变 ...return topk_idx# 创建 MoEGate 实例
gate = MoEGate(dim=512, # 输入维度n_routed_experts=4, # 专家数量num_experts_per_tok=2 # 每个token选择的专家数量
)# 创建示例输入
batch_size = 2
seq_len = 10
hidden_dim = 512
x = torch.randn(batch_size, seq_len, hidden_dim) # 形状: [2, 10, 512]torch.onnx.export(gate, # 要导出的模型x, # 模型输入"moe_gate.onnx", # 输出文件名input_names=['input'], # 输入名称output_names=['output'], # 输出名称opset_version=17
)
F.linear操作
# 假设参数
n_routed_experts = 4 # 专家数量
gating_dim = 512 # 输入维度
batch_size = 2 # 批次大小
seq_len = 10 # 序列长度# 1. 创建权重矩阵
weight = nn.Parameter(torch.empty((n_routed_experts, gating_dim)))
print(f"权重矩阵维度: {weight.shape}") # [4, 512]# 2. 创建输入张量
hidden_states = torch.randn(batch_size, seq_len, gating_dim) # [2, 10, 512]
print(f"原始输入维度: {hidden_states.shape}") # [2, 10, 512]# 3. 重塑输入
hidden_states = hidden_states.view(-1, gating_dim) # [20, 512]
print(f"重塑后输入维度: {hidden_states.shape}") # [20, 512]# 4. 线性变换
logits = F.linear(hidden_states, weight) # [20, 4]
print(f"输出维度: {logits.shape}") # [20, 4]
权重矩阵维度: torch.Size([4, 512])
原始输入维度: torch.Size([2, 10, 512])
重塑后输入维度: torch.Size([20, 512])
输出维度: torch.Size([20, 4])
这个矩阵乘法的维度变化:
在 F.linear(hidden_states, weight)
中:
hidden_states
: [20, 512]weight
: [4, 512]
这是一个矩阵乘法运算:[20, 512] × [512, 4]
- 注意:
weight
在运算时会被转置,从 [4, 512] 变成 [512, 4] - 矩阵乘法规则:第一个矩阵的列数(512)必须等于第二个矩阵的行数(512)
- 结果维度:第一个矩阵的行数(20) × 第二个矩阵的列数(4)
用小矩阵举例说明:
# 简化的例子
h = torch.randn(3, 5) # [3, 5] 相当于 [20, 512]
w = torch.randn(2, 5) # [2, 5] 相当于 [4, 512]
out = F.linear(h, w) # [3, 2] 相当于 [20, 4]# 等价于
w_t = w.t() # [5, 2] 转置
out = h @ w_t # [3, 5] × [5, 2] = [3, 2]
所以 [20, 512] × [512, 4] = [20, 4],其中:
- 20 表示批次大小×序列长度(每个token一行)
- 4 表示每个token得到的专家数量的分数