完整代码来源:DiT
DiT模型主要是在diffusion中,使用transformer模型替换了UNet模型,使用class来控制图像生成。
根据论文,模型越大,patch size 越小,FID越小。
模型越大,参数越多,patch size越小,参与计算的信息就越多,模型效果越好。
模型使用了Imagenet 训练,有1000个分类,class_labels是0到999的整数,无条件类则必须是1000,在class embedding的时候定义了 nn.Embedding(1000+1,1152)。
在模型初始化的时候,DiTBlock和FinalLayer的参数中weight和bias被置为0,也就是adaLN_zero。
在DiTBlock 中,参数 c = t + y ,也就是 class embedding和 t embedding的和,然后通过线性映射生成shift,scale,和gate。
def modulate(x, shift, scale):return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)class DiTBlock(nn.Module):"""A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning."""def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):super().__init__()self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)mlp_hidden_dim = int(hidden_size * mlp_ratio)approx_gelu = lambda: nn.GELU(approximate="tanh")self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)self.adaLN_modulation = nn.Sequential(nn.SiLU(),nn.Linear(hidden_size, 6 * hidden_size, bias=True))def forward(self, x, c):shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))return x# Zero-out adaLN modulation layers in DiT blocks:for block in self.blocks:nn.init.constant_(block.adaLN_modulation[-1].weight, 0)nn.init.constant_(block.adaLN_modulation[-1].bias, 0)# Zero-out output layers:nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)nn.init.constant_(self.final_layer.linear.weight, 0)nn.init.constant_(self.final_layer.linear.bias, 0)