1. Network Structure
The defining feature of DenseNet is that, for feature maps of the same spatial size, each layer is connected to all preceding and all subsequent layers, and these connections concatenate feature maps rather than simply summing them. The network is built mainly from dense blocks and transition layers.
Structure overview:
- Dense connectivity: each layer receives the outputs of all preceding layers and passes its own output to every subsequent layer.
- Feature reuse: dense connectivity lets every layer directly access the features of all preceding layers, which reduces feature redundancy. The features produced by each layer are reused by all subsequent layers, so the model learns more efficiently.
- Growth rate: the number of new feature maps each layer produces is set by the growth rate k.
- Transition layers: placed between dense blocks, mainly to reduce the number of channels.
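The growth-rate arithmetic can be sketched in a few lines (a minimal sketch; the function name and the notation `c0` for the initial channel count and `k` for the growth rate are my own, not from the original):

```python
# Channel bookkeeping inside one dense block: because each layer's k new
# feature maps are concatenated with everything before it, the channel
# count grows linearly with depth.
def dense_block_channels(c0, k, num_layers):
    """Return the channel count after each layer of a dense block."""
    return [c0 + i * k for i in range(num_layers + 1)]

# Example: 24 initial channels, growth rate k=12, 10 layers
# -> final channel count 24 + 10*12 = 144.
print(dense_block_channels(24, 12, 10)[-1])  # 144
```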
The figure below shows the DenseNet network structure:
Main components of DenseNet:
1. Initial convolution layer: typically a 7×7 convolution followed by batch normalization, ReLU, and max pooling.
2. Dense block: contains several convolution layers; each layer's output is concatenated with the outputs of all preceding layers.
3. Transition layer: sits between dense blocks, adjusting the number of feature maps and performing downsampling.
4. Global average pooling: the final stage replaces the fully connected layer with global average pooling, reducing the parameter count and improving generalization.
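To make the interplay of dense blocks and transition layers concrete, here is a small sketch that traces the channel counts through a network with four 10-layer blocks (the configuration used later in this post); the function name and defaults are my own:

```python
def trace_channels(c0=24, k=12, convs_per_block=(10, 10, 10, 10)):
    """Trace the output channel count after each dense block and
    transition layer (transitions halve the channels)."""
    c = c0
    trace = []
    for i, n in enumerate(convs_per_block):
        c += n * k                      # dense block: concatenation adds n*k channels
        trace.append(("DenseBlock_%d" % i, c))
        if i != len(convs_per_block) - 1:
            c //= 2                     # transition layer halves the channels
            trace.append(("Transition_%d" % i, c))
    return trace

for name, c in trace_channels():
    print(name, c)
```

With these settings the last dense block ends at 228 channels, which is the input width of the final classifier.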
2. Code Implementation
Here I use L=40 and k=12: four dense blocks, each with 10 convolutions, an initial channel count of 24, and a growth rate of 12.
The dense block:
# Dense block
import torch
import torch.nn as nn

def conv_block(in_channels, out_channels):
    # BN-ReLU-Conv(1x1) bottleneck followed by BN-ReLU-Conv(3x3)
    blk = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
    )
    return blk

class DenseBlock(nn.Module):
    def __init__(self, num_conv, in_channels, out_channels) -> None:
        super().__init__()
        blk = []
        for i in range(num_conv):
            # Each layer's input grows by out_channels (the growth rate)
            inc = in_channels + i * out_channels
            blk.append(conv_block(inc, out_channels))
        self.net = nn.ModuleList(blk)
        self.out_channels = in_channels + num_conv * out_channels

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # concatenate along the channel axis
        return X
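The concatenation pattern in `forward` can be seen in isolation with random tensors (a sketch; the random tensors stand in for the `conv_block` outputs):

```python
import torch

X = torch.rand(1, 24, 8, 8)   # (batch, channels, height, width)
growth_rate = 12
for _ in range(10):
    # Stand-in for conv_block(X): each layer emits growth_rate new channels
    Y = torch.rand(1, growth_rate, 8, 8)
    X = torch.cat((X, Y), dim=1)  # concatenate along the channel axis
print(X.shape)  # torch.Size([1, 144, 8, 8])
```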
The transition layer:
class Transition(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu1 = nn.ReLU()
        self.cv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.avg = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, X):
        X1 = self.cv1(self.relu1(self.bn1(X)))  # 1x1 conv reduces the channels
        X2 = self.avg(X1)                       # stride-2 pooling halves the spatial size
        return X2
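As a quick sanity check (a sketch with made-up sizes, using only the two core operations of the transition layer), the 1×1 convolution reduces the channels while the stride-2 pooling halves the spatial dimensions:

```python
import torch
import torch.nn as nn

# 144 -> 72 channels via 1x1 conv; 28x28 -> 14x14 via stride-2 average pool
conv1x1 = nn.Conv2d(144, 72, kernel_size=1)
pool = nn.AvgPool2d(kernel_size=2, stride=2)

X = torch.rand(1, 144, 28, 28)
out = pool(conv1x1(X))
print(out.shape)  # torch.Size([1, 72, 14, 14])
```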
The DenseNet model:
# DenseNet model
net = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=24, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(24),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
)

num_channels, growth_rate = 24, 12
num_conv_in_denseblocks = [10, 10, 10, 10]

for i, num_conv in enumerate(num_conv_in_denseblocks):
    DB = DenseBlock(num_conv, num_channels, growth_rate)
    net.add_module("DenseBlock_%d" % i, DB)
    # Output channels of the previous dense block
    num_channels = DB.out_channels
    # Add a transition layer that halves the channels between dense blocks
    if i != len(num_conv_in_denseblocks) - 1:
        TR = Transition(num_channels, num_channels // 2)
        num_channels = num_channels // 2
        net.add_module("Transition_%d" % i, TR)

# Global average pooling layer
class Global_Avg_Pooling(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        # Pool each feature map down to a single value
        return nn.functional.avg_pool2d(X, X.size()[2:])

# Flatten before the fully connected layer
class FlattenLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X.view(X.size()[0], -1)

net.add_module("global_avg_pool", Global_Avg_Pooling())
net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(num_channels, 10)))
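For reference, the same classification head can also be built from PyTorch's built-in `nn.AdaptiveAvgPool2d` and `nn.Flatten` modules (a sketch; the 228 input features match the final channel count of the configuration in this post):

```python
import torch
import torch.nn as nn

head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),  # global average pooling to 1x1
    nn.Flatten(),             # (batch, channels, 1, 1) -> (batch, channels)
    nn.Linear(228, 10),       # 228 channels after the last dense block here
)

X = torch.rand(2, 228, 7, 7)
print(head(X).shape)  # torch.Size([2, 10])
```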
Create a four-dimensional tensor to inspect the network structure; the output shapes are shown below:
X = torch.rand(1, 1, 224, 224)
for name, m in net.named_children():
    X = m(X)
    print(name, '->', X.shape)
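Once the shapes check out, a single optimization step can be sketched on random data (a hedged sketch: the tiny stand-in model, batch size, and learning rate are my own choices, not from the post; in practice one would train the DenseNet defined above on a real dataset):

```python
import torch
import torch.nn as nn

# Tiny stand-in model so the snippet is self-contained and fast
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

X = torch.rand(4, 1, 224, 224)   # random images standing in for a dataset
y = torch.randint(0, 10, (4,))   # random class labels

optimizer.zero_grad()
loss = loss_fn(model(X), y)      # forward pass + loss
loss.backward()                  # backpropagation
optimizer.step()                 # parameter update
```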