1. Transformer Attention
The attention module is an important mechanism in deep learning that lets a model dynamically focus on the most relevant parts of its input. It was first proposed in natural language processing (NLP) and has since spread to computer vision, speech recognition, and other fields.
Core idea: the attention mechanism mimics how humans allocate attention, dynamically focusing on different parts of the input depending on the task at hand. By computing weights, the model decides which parts of the information matter most for the current task.
Main types:
1. Global attention: attends to all positions of the input sequence. Suited to tasks that need global information, such as machine translation.
2. Local attention: attends to only part of the input sequence. Suited to long-sequence tasks, where it reduces computational complexity.
3. Self-attention: elements of the input sequence attend to one another. Used extensively in the Transformer model.
4. Multi-head attention: several attention heads process the input in parallel, capturing information in different subspaces and strengthening the model's expressive power. (Types 3 and 4 are illustrated in the sketch after this list.)
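For types 3 and 4, here is a minimal self-attention sketch using PyTorch's built-in nn.MultiheadAttention; the tensor shapes are illustrative assumptions:

import torch
import torch.nn as nn

# 8-head self-attention over a sequence of 10 tokens with embedding size 512.
# nn.MultiheadAttention defaults to the (seq_len, batch, embed_dim) layout.
attn = nn.MultiheadAttention(embed_dim=512, num_heads=8)
x = torch.randn(10, 2, 512)   # 10 tokens, batch of 2

out, weights = attn(x, x, x)  # query = key = value -> self-attention
print(out.shape)              # torch.Size([10, 2, 512])
print(weights.shape)          # torch.Size([2, 10, 10]), averaged over heads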
Computation steps:
1. Compute attention scores: scores are computed from queries (Query), keys (Key), and values (Value), typically with dot-product or additive attention.
2. Normalize: a Softmax turns the scores into a probability distribution.
3. Weighted sum: the values are summed, weighted by the normalized scores, to produce the output. (A worked sketch of all three steps follows.)
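These three steps map directly onto scaled dot-product attention. Below is a minimal sketch, assuming the standard formulation with scaling by the square root of the key dimension; the shapes are illustrative:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Step 1: attention scores via scaled dot product
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., L_q, L_k)
    # Step 2: normalize scores into a probability distribution
    weights = F.softmax(scores, dim=-1)
    # Step 3: weighted sum of the values
    return weights @ V

Q = torch.randn(2, 5, 64)  # batch of 2, 5 queries, dim 64
K = torch.randn(2, 7, 64)  # 7 keys
V = torch.randn(2, 7, 64)  # 7 values
print(scaled_dot_product_attention(Q, K, V).shape)  # torch.Size([2, 5, 64])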
2. ResNet + Transformer
Integrating a Transformer module into ResNet is usually done to combine a convolutional neural network's (CNN's) local feature extraction with the Transformer's global modeling ability.
Here the Transformer is inserted just before the final fully connected layer.
The code is as follows:
import torch
import torch.nn as nn
import torchvision.models as models


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention with residual connection and post-norm
        attn_output, _ = self.attention(x, x, x)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        # Feed-forward network with residual connection and post-norm
        ffn_output = self.ffn(x)
        x = x + self.dropout(ffn_output)
        x = self.norm2(x)
        return x


class ResNetWithTransformer(nn.Module):
    def __init__(self, num_classes, embed_dim=512, num_heads=8, ff_dim=2048, num_layers=2, dropout=0.1):
        super(ResNetWithTransformer, self).__init__()
        self.resnet = models.resnet18(pretrained=False)  # ResNet18 backbone, no pretrained weights
        self.resnet.fc = nn.Identity()  # Remove the final fully connected layer
        self.transformer = nn.Sequential(
            *[TransformerBlock(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # ResNet feature extraction
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        x = self.resnet.avgpool(x)
        x = torch.flatten(x, 1)  # (batch, 512)
        x = x.unsqueeze(0)       # Add sequence dimension: (1, batch, 512)
        x = self.transformer(x)
        x = x.squeeze(0)         # Remove sequence dimension
        x = self.fc(x)
        return x


# Example usage
model = ResNetWithTransformer(num_classes=5)
print(model)

input_tensor = torch.randn(1, 3, 224, 224)
output = model(input_tensor)
print(output.shape)  # Should print torch.Size([1, 5])
Code explanation
- TransformerBlock:
  - Implements a basic Transformer block: multi-head self-attention followed by a feed-forward network.
  - Uses LayerNorm and Dropout to stabilize training.
- ResNetWithTransformer:
  - Uses ResNet18 as the feature extractor (instantiated with pretrained=False, so no pretrained weights are loaded).
  - Removes ResNet's final fully connected layer and replaces it with a custom Transformer module.
  - The Transformer module processes the features extracted by ResNet, and a final fully connected layer produces the classification output.
- Forward pass:
  - The input image passes through ResNet's convolutional and pooling layers to extract features.
  - The pooled features are flattened and passed through the Transformer module (as a length-1 sequence) for global modeling.
  - A final fully connected layer outputs the classification result.
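One design point worth noting: because the features are pooled before the Transformer, the sequence it sees has length 1, so self-attention has only a single position to attend to. A common alternative, sketched below as a hypothetical variant (not what the code above does), is to skip avgpool and feed layer4's 7x7 spatial map as 49 tokens:

# Hypothetical variant of ResNetWithTransformer.forward: treat the 7x7
# layer4 output as a 49-token sequence instead of pooling first.
def forward_spatial_tokens(self, x):
    x = self.resnet.conv1(x)
    x = self.resnet.bn1(x)
    x = self.resnet.relu(x)
    x = self.resnet.maxpool(x)
    x = self.resnet.layer1(x)
    x = self.resnet.layer2(x)
    x = self.resnet.layer3(x)
    x = self.resnet.layer4(x)          # (B, 512, 7, 7) for 224x224 input
    x = x.flatten(2).permute(2, 0, 1)  # (49, B, 512): 49 spatial tokens
    x = self.transformer(x)            # attention across spatial positions
    x = x.mean(dim=0)                  # average tokens back to (B, 512)
    return self.fc(x)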
Notes
- Input size: the code assumes 224x224 input images; adjust as needed. ResNet's AdaptiveAvgPool2d also tolerates other input sizes.
- Transformer parameters: embed_dim, num_heads, ff_dim, and so on can be tuned to the task; note that embed_dim must match the backbone's feature dimension (512 for ResNet18).
- Training: in practice you may need to tune the learning rate, optimizer, and other hyperparameters; a minimal sketch follows this list.
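As referenced in the training note above, here is a minimal training-step sketch; CrossEntropyLoss, Adam, the learning rate, and the dummy batch are illustrative assumptions, not prescriptions:

import torch
import torch.nn as nn

model = ResNetWithTransformer(num_classes=5)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Dummy batch standing in for a real DataLoader
images = torch.randn(8, 3, 224, 224)
labels = torch.randint(0, 5, (8,))

model.train()
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
optimizer.step()
print(loss.item())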
Summary
Integrating a Transformer module into ResNet combines a CNN's local feature extraction with the Transformer's global modeling ability, which suits tasks that need to capture both local and global information.
The output is as follows:
ResNetWithTransformer(
(resnet): ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Identity()
)
(transformer): Sequential(
(0): TransformerBlock(
(attention): MultiheadAttention(
(out_proj): Linear(in_features=512, out_features=512, bias=True)
)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(ffn): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): ReLU()
(2): Linear(in_features=2048, out_features=512, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
)
(1): TransformerBlock(
(attention): MultiheadAttention(
(out_proj): Linear(in_features=512, out_features=512, bias=True)
)
(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(ffn): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): ReLU()
(2): Linear(in_features=2048, out_features=512, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(fc): Linear(in_features=512, out_features=5, bias=True)
)
torch.Size([1, 5])