PyTorch Transformer模块详解目录基础算子层核心模块层位置编码前馈网络完整架构模块学习路径建议输入输出形状速查表基础算子层torch.matmul功能说明矩阵乘法是注意力机制的核心操作用于计算查询矩阵(Q)和键矩阵(K)的点积。# 计算注意力分数attn_scorestorch.matmul(q,k.transpose(-2,-1))参数说明q: 查询矩阵形状[B, h, L_q, d]k: 键矩阵形状[B, h, L_k, d]k.transpose(-2, -1) : 形状[B, h, d,L_k]输出注意力分数矩阵形状[B, h, L_q, L_k]torch.softmax功能说明将注意力分数转换为概率分布使每个位置的注意力权重和为1。attn_weightstorch.softmax(attn_scores,dim-1)参数说明attn_scores: 注意力分数矩阵dim: 指定在哪个维度上进行softmax计算通常为-1最后一个维度输出注意力权重矩阵形状与输入相同最后一维和为1torch.masked_fill功能说明掩码操作用于屏蔽padding位置或未来信息将指定位置的值替换为指定值。attn_scoresattn_scores.masked_fill(mask0,-1e9)参数说明mask: 掩码矩阵0表示需要屏蔽的位置-1e9: 替换值通常使用一个很大的负数经过softmax后会趋近于0输出被掩码处理后的注意力分数矩阵torch.sqrt功能说明计算平方根用于缩放点积注意力防止点积结果过大导致梯度消失。scaletorch.sqrt(torch.tensor(head_dim,dtypetorch.float32))attn_scoresattn_scores/scale参数说明head_dim- 每个注意力头的维度输出缩放因子张量形状变换操作功能说明用于多头注意力的拆分与合并包括view、reshape、transpose、contiguous等操作。# 拆分多头qq.view(B,L,h,d).transpose(1,2)# [B, L, d_model] - [B, h, L, d]# 合并多头outout.transpose(1,2).contiguous().view(B,L,d_model)# [B, h, L, d] - [B, L, d_model]核心模块层nn.Linear功能说明全连接层用于Q/K/V/O投影和前馈网络中的线性变换。self.w_qnn.Linear(d_model,d_model)参数说明in_features: 输入特征维度out_features: 输出特征维度bias: 是否使用偏置项默认为True输入[B, L, d_model]输出[B, L, d_model]nn.Dropout功能说明Dropout层用于防止过拟合在训练时随机将部分神经元输出置零。self.dropoutnn.Dropout(p0.1)参数说明p- 丢弃概率0.1表示10%的神经元被随机置零输入任意形状张量输出同形状张量训练时部分元素被置零nn.LayerNorm功能说明层归一化对每个样本的特征维度进行归一化加速训练并提高模型稳定性。self.normnn.LayerNorm(d_model)参数说明normalized_shape: 需要归一化的维度大小eps: 数值稳定性小量默认1e-5elementwise_affine: 是否学习缩放和平移参数默认True输入[B, L, d_model]输出[B, L, d_model]最后一维被归一化nn.Embedding功能说明词嵌入层将离散的token ID映射为连续的向量表示。self.embeddingnn.Embedding(vocab_size,d_model)参数说明num_embeddings: 词汇表大小embedding_dim: 嵌入向量维度输入[B, L]元素为token ID的LongTensor输出[B, L, d_model]位置编码PositionalEncoding功能说明由于Transformer没有循环或卷积结构需要显式添加位置信息。使用正弦和余弦函数生成位置编码。classPositionalEncoding(nn.Module):def__init__(self,d_model,max_len5000):super().__init__()petorch.zeros(max_len,d_model)positiontorch.arange(0,max_len).unsqueeze(1)div_termtorch.exp(torch.arange(0,d_model,2)*(-math.log(10000.0)/d_model))pe[:,0::2]torch.sin(position*div_term)pe[:,1::2]torch.cos(position*div_term)self.register_buffer(pe,pe)defforward(self,x):returnxself.pe[:x.size(1)]参数说明d_model: 模型维度max_len: 最大序列长度输入[B, L, d_model]输出[B, L, d_model]加上了位置编码信息前馈网络PositionwiseFeedForward功能说明位置级前馈网络对序列中每个位置独立进行相同的变换。classPositionwiseFeedForward(nn.Module):def__init__(self,d_model,d_ff,dropout0.1):super().__init__()self.fc1nn.Linear(d_model,d_ff)self.fc2nn.Linear(d_ff,d_model)self.dropoutnn.Dropout(dropout)self.activationnn.ReLU()defforward(self,x):returnself.fc2(self.dropout(self.activation(self.fc1(x))))参数说明d_model: 模型维度d_ff: 隐藏层维度通常为4*d_modeldropout: Dropout概率输入[B, L, d_model]输出[B, L, d_model]完整架构模块Encoder Layer结构组成多头自注意力 前馈网络每个子层后都有残差连接和层归一化。输入 X → MultiHeadAttention(QX, KX, VX) → Add(X) → LayerNorm → → FeedForward → Add → LayerNorm → 输出Decoder Layer结构组成三个子层掩码多头自注意力、多头注意力Encoder-Decoder Attention、前馈网络。输入 Y → Masked MultiHeadAttention(QY, KY, VY) → Add(Y) → LayerNorm → → MultiHeadAttention(QY, KEncoder输出, VEncoder输出) → Add → LayerNorm → → FeedForward → Add → LayerNorm → 输出整体架构Input → Embedding PositionalEncoding → N × EncoderLayer → Encoder输出 Target → Embedding PositionalEncoding → N × DecoderLayer → Linear → Softmax → 输出概率学习路径建议按以下顺序逐个实现由简到繁Scaled Dot-Product Attention缩放点积注意力Multi-Head Attention多头注意力Position-wise Feed Forward前馈网络Positional Encoding位置编码Encoder Layer编码器层Decoder Layer解码器层完整 Transformer拼接 Encoder Decoder输入输出形状速查表模块输入形状输出形状说明Embedding[B, L][B, L, d_model]Bbatch_size, Lseq_lenPositionalEncoding[B, L, d_model][B, L, d_model]添加位置信息MultiHeadAttentionQ/K/V: [B, L, d_model][B, L, d_model]多头注意力计算FeedForward[B, L, d_model][B, L, d_model]位置级前馈网络LayerNorm[B, L, d_model][B, L, d_model]层归一化Linear (vocab投影)[B, L, d_model][B, L, vocab_size]词汇表投影完整Transformer架构图Encoder Input → Embedding → PositionalEncoding → [MultiHeadAttention → Add Norm → FeedForward → Add Norm] × N → Encoder输出 Decoder Target → Embedding → PositionalEncoding → [MaskedMultiHeadAttention → Add Norm → MultiHeadAttention → Add Norm → FeedForward → Add Norm] × N → Linear → Softmax → 输出概率通过掌握以上模块您将能够从零开始实现完整的Transformer架构。建议按照学习路径逐步实现每完成一个模块都进行充分的测试验证。