当前位置: 首页> 科技> 能源 > Mamba-YOLO : 基于SSM的YOLO目标检测算法(附代码)

Mamba-YOLO : 基于SSM的YOLO目标检测算法(附代码)

时间:2025/9/6 20:45:54来源:https://blog.csdn.net/athrunsunny/article/details/141437558 浏览次数:0次

代码地址:GitHub - HZAI-ZJNU/Mamba-YOLO: the official pytorch implementation of “Mamba-YOLO:SSMs-based for Object Detection”

论文地址:https://arxiv.org/pdf/2406.05835

在深度学习技术的快速进步推动下,YOLO系列为实时目标检测器设立了新的基准。研究人员在YOLO的基础上,不断探索重新参数化、高效层聚合网络和anchor-free技术的创新应用。为了进一步提高检测性能,引入了基于Transformer的结构,显著扩展了模型的感受野,并实现了显著的性能增益。然而,这种改进是有代价的,因为自我注意机制的二次复杂度增加了模型的计算负担。幸运的是,状态空间模型(SSM)作为一种创新技术的出现有效地缓解了由二次复杂度引起的问题。鉴于这些进展,作者引入了一种新的基于SSM的目标检测模型Mamba-YOLO。Mamba-YOLO不仅优化了SSM基础,而且专门适用于目标检测任务。考虑到SSM在序列建模中的潜在局限性,如感受野不足和图像局部性弱,作者设计了LSBlock和RGBlock。这些模块能够更精确地捕获局部图像相关性,并显著增强模型的鲁棒性。在公开的基准数据集COCO和VOC上的大量实验结果表明,Mamba-YOLO在性能和竞争力方面都超过了现有的YOLO系列模型,展示了其巨大的潜力和竞争优势。


动机:为进一步提升目标检测性能,基于Transformer的网络检测结构被提出,引入Transformer显著扩大了模型的感受野,但增加了模型的计算负担。近期,基于状态空间模型(SSMs)的方法,如 Mamba,由于其强大的长距离依赖建模能力和线性时间复杂度的优势,为实现高效的目标检测提供了新思路。

创新点1:作者提出了Mamba-YOLO,以YOLOv8架构为基础,采用基于状态空间模型(SSM)的方法在目标检测方面建立了新的基准,为开发基于SSM的检测器提供了较好的基础。

创新点2:类比于YOLOv8的C2f模块,作者提出了ODSSBlock,ODSSBlock主要由LocalSpatial Block和ResGated Block模块组成,其中,LocalSpatial Block能够有效地提取输入特征图的局部空间信息,以补偿SSM的局部建模能力。通过重新思考MLP层的设计,作者提出了结合了门控聚合与有效卷积残差连接思想的ResGated Block,可有效地捕捉局部依赖关系并增强模型的鲁棒性。


Vision State Space Models

状态空间模型是近年来研究的热点。基于对SSM的研究,Mamba在输入大小上表现出线性复杂性,并解决了Transformer在建模状态空间的长序列上的计算效率问题。在广义视觉主干领域,Vision Mamba提出了一种基于SSM的纯视觉主干模型,标志着Mamba首次被引入视觉领域。VMamba引入了交叉扫描模块,使模型能够对2D图像进行选择性扫描增强视觉处理,并展示了在图像分类任务上的优势。LocalMamba专注于视觉模型的窗口扫描策略,优化视觉信息以捕获局部依赖关系,并引入动态扫描方法来搜索不同层的最佳选择。MambaOut探讨了Mamba架构在视觉任务中的必要性,指出SSM对于图像分类任务不是必要的,但它对于遵循长序列特征的检测和分割任务的价值值得进一步探索。在下游视觉任务中,Mamba 也被广泛应用于医学图像分割和遥感图像分割的研究。受VMamba在视觉任务领域取得的显著成果的启发,本文首次提出了Mamba YOLO,这是一种新的SSM模型,旨在考虑全局感觉场,同时展示其在目标检测任务中的潜力。


Preliminaries

源于状态空间模型(SSM)的结构化状态空间序列模型S4和Mamba都源于一个连续系统,该系统通过隐式潜在中间状态将单变量序列映射到输出序列中。这种设计不仅桥接了输入和输出之间的关系,而且封装了时间动态。该系统的数学定义如下: 

Mamba通过使用固定的离散化规则fA和fB将该连续系统应用于离散时间序列数据,以将参数A和B分别转换为其离散对应物,从而将系统更好地集成到深度学习架构中。用于此目的的常用判断方法是零阶保持(ZOH)。离散版本可以定义如下: 

转换后,模型通过线性递归形式进行计算,其定义如下:

整个序列变换也可以用卷积形式表示,其定义如下:


Overall Architecture

Mamba YOLO的体系结构概述如图2所示。作者的目标检测模型分为ODMamba主干部分和颈部部分。ODMamba由简单Stem、下采样block组成。在neck,遵循PAN-FPN的设计,使用ODSBlock模块而不是C2f来捕获更梯度丰富的信息流。主干首先通过Stem模块进行下采样,得到分辨率为H/4、W/4的2D特征图。因此,所有模型都由ODSBlock和VisionVue合并模块组成,用于进一步的下采样。在颈部,采用了PAFPN的设计,使用ODSSBlock代替C2f,其中Conv全权负责下采样。

Simple Stem:Modern Vision Transformers(ViTs)通常使用分割块作为其初始模块,将图像划分为不重叠的片段。该分割过程是通过核大小为4、步长为4的卷积运算来实现的。然而,最近的研究,如EfficientFormerV2的研究表明,这种方法可能会限制ViT的优化能力,影响整体性能。为了在性能和效率之间取得平衡,作者提出了一种精简的卷积层。使用两个步长为2、核大小为3的卷积,而不是使用不重叠的patches。同时为了保持速度,隐藏层通道数设置为输出的一半。

class SimpleStem(nn.Module):def __init__(self, inp, embed_dim, ks=3):super().__init__()self.hidden_dims = embed_dim // 2self.conv = nn.Sequential(nn.Conv2d(inp, self.hidden_dims, kernel_size=ks, stride=2, padding=autopad(ks, d=1), bias=False),nn.BatchNorm2d(self.hidden_dims),nn.GELU(),nn.Conv2d(self.hidden_dims, embed_dim, kernel_size=ks, stride=2, padding=autopad(ks, d=1), bias=False),nn.BatchNorm2d(embed_dim),nn.SiLU(),)def forward(self, x):return self.conv(x)

Vision Clue Merge:虽然卷积神经网络(CNNs)和视觉Transformer(ViT)结构通常使用卷积进行下采样,但作者发现这种方法会干扰SS2D在不同信息流阶段的选择性操作。为了解决这一问题,VMamba分割2D特征图,并使用1x1卷积来降低维度。作者的研究结果表明,为状态空间模型(SSM)保留更多的视觉线索有利于模型训练。与传统的尺寸减半相比,作者通过以下方式简化了这一过程:1)删除规范;2) 拆分维度图;3) 将多余的特征图附加到通道维度;4) 利用4倍压缩逐点卷积进行下采样。与使用步长为2的3x3卷积不同,作者的方法保留了SS2D从上一层选择的特征图。

class VisionClueMerge(nn.Module):def __init__(self, dim, out_dim):super().__init__()self.hidden = int(dim * 4)self.pw_linear = nn.Sequential(nn.Conv2d(self.hidden, out_dim, kernel_size=1, stride=1, padding=0),nn.BatchNorm2d(out_dim),nn.SiLU())def forward(self, x):y = torch.cat([x[..., ::2, ::2],x[..., 1::2, ::2],x[..., ::2, 1::2],x[..., 1::2, 1::2]], dim=1)return self.pw_linear(y)

这部分和yolov5的focus结构一样 


ODSS Block 

ODSSBlock是Mamba-YOLO的核心模块,该模块主要包含LSBlock、RGBlock、SS2D三个模块。在输入阶段经过一系列处理,可以使网络能够学习到更深入、更丰富的特征表示,同时通过批处理归一化保持训练推理过程的高效和稳定。ODSSBlock的批归一化、层归一化和残差连接设计允许模型在深层堆叠训练时有效流动。 

其中\widehat{\Phi }表示激活函数

其中LS表示LocalSpatial Block ,RG表示ResGated Block

Scan Expansion, S6 Block和Scan Merge是SS2D算法的三个主要步骤,其主要流程如图3所示。

Scan Expansion操作将输入图像扩展为一系列子图像,每个子图像表示特定的方向,并且当从对角视点观察时,Scan Expansion操作沿着四个对称方向进行处理,这四个方向分别是自上而下、自下而上、左右和单词从右到左。这样的布局不仅全面覆盖了输入图像的所有区域,而且通过系统的方向变换为后续的特征提取提供了丰富的多维信息库,从而提高了图像特征多维捕捉的效率和全面性。然后,在S6块操作中将这些子图像提交给特征提取,并且最后通过Scan Merge操作,将这些子图像合并在一起以形成与输入图像相同大小的输出图像。

S6 block操作如下图所示,主要是通过状态空间模型SSM来实现参数更新和学习,SSM 的作用是基于一个连续系统,将单变量序列 x(t) 映射到输出序列 y(t),通过隐含的中间状态 h(t) 实现输入和输出的关系。其中,参数矩阵A、B、C、D以及离散化时间间距Δ具有线性特性,参数矩阵A采用零阶保持方法离散化表示为eΔA,参数矩阵B、C以及离散化时间间距Δ在实现SSM模型时采用linear实现,参数矩阵D一般采用残差连接方式实现,在状态空间模型中D=1。此外,参数矩阵B根据一阶泰勒级数展开可近似为ΔB。 

class SS2D(nn.Module):def __init__(self,# basic dims ===========d_model=96,d_state=16,ssm_ratio=2.0,ssm_rank_ratio=2.0,dt_rank="auto",act_layer=nn.SiLU,# dwconv ===============d_conv=3,  # < 2 means no convconv_bias=True,# ======================dropout=0.0,bias=False,# ======================forward_type="v2",**kwargs,):"""ssm_rank_ratio would be used in the future..."""factory_kwargs = {"device": None, "dtype": None}super().__init__()d_expand = int(ssm_ratio * d_model)d_inner = int(min(ssm_rank_ratio, ssm_ratio) * d_model) if ssm_rank_ratio > 0 else d_expandself.dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rankself.d_state = math.ceil(d_model / 6) if d_state == "auto" else d_state  # 20240109self.d_conv = d_convself.K = 4# tags for forward_type ==============================def checkpostfix(tag, value):ret = value[-len(tag):] == tagif ret:value = value[:-len(tag)]return ret, valueself.disable_force32, forward_type = checkpostfix("no32", forward_type)self.disable_z, forward_type = checkpostfix("noz", forward_type)self.disable_z_act, forward_type = checkpostfix("nozact", forward_type)self.out_norm = nn.LayerNorm(d_inner)# forward_type debug =======================================FORWARD_TYPES = dict(v2=partial(self.forward_corev2, force_fp32=None, SelectiveScan=SelectiveScanCore),)self.forward_core = FORWARD_TYPES.get(forward_type, FORWARD_TYPES.get("v2", None))# in proj =======================================d_proj = d_expand if self.disable_z else (d_expand * 2)self.in_proj = nn.Conv2d(d_model, d_proj, kernel_size=1, stride=1, groups=1, bias=bias, **factory_kwargs)self.act: nn.Module = nn.GELU()# conv =======================================if self.d_conv > 1:self.conv2d = nn.Conv2d(in_channels=d_expand,out_channels=d_expand,groups=d_expand,bias=conv_bias,kernel_size=d_conv,padding=(d_conv - 1) // 2,**factory_kwargs,)# rank ratio =====================================self.ssm_low_rank = Falseif d_inner < d_expand:self.ssm_low_rank = Trueself.in_rank = nn.Conv2d(d_expand, d_inner, kernel_size=1, bias=False, **factory_kwargs)self.out_rank = nn.Linear(d_inner, d_expand, bias=False, **factory_kwargs)# x proj ============================self.x_proj = [nn.Linear(d_inner, (self.dt_rank + self.d_state * 2), bias=False,**factory_kwargs)for _ in range(self.K)]self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)del self.x_proj# out proj =======================================self.out_proj = nn.Conv2d(d_expand, d_model, kernel_size=1, stride=1, bias=bias, **factory_kwargs)self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()# simple init dt_projs, A_logs, Dsself.Ds = nn.Parameter(torch.ones((self.K * d_inner)))self.A_logs = nn.Parameter(torch.zeros((self.K * d_inner, self.d_state)))  # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1self.dt_projs_weight = nn.Parameter(torch.randn((self.K, d_inner, self.dt_rank)))self.dt_projs_bias = nn.Parameter(torch.randn((self.K, d_inner)))@staticmethoddef dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,**factory_kwargs):dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)# Initialize special dt projection to preserve variance at initializationdt_init_std = dt_rank ** -0.5 * dt_scaleif dt_init == "constant":nn.init.constant_(dt_proj.weight, dt_init_std)elif dt_init == "random":nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)else:raise NotImplementedError# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_maxdt = torch.exp(torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))+ math.log(dt_min)).clamp(min=dt_init_floor)# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759inv_dt = dt + torch.log(-torch.expm1(-dt))with torch.no_grad():dt_proj.bias.copy_(inv_dt)# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit# dt_proj.bias._no_reinit = Truereturn dt_proj@staticmethoddef A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):# S4D real initializationA = repeat(torch.arange(1, d_state + 1, dtype=torch.float32, device=device),"n -> d n",d=d_inner,).contiguous()A_log = torch.log(A)  # Keep A_log in fp32if copies > 0:A_log = repeat(A_log, "d n -> r d n", r=copies)if merge:A_log = A_log.flatten(0, 1)A_log = nn.Parameter(A_log)A_log._no_weight_decay = Truereturn A_log@staticmethoddef D_init(d_inner, copies=-1, device=None, merge=True):# D "skip" parameterD = torch.ones(d_inner, device=device)if copies > 0:D = repeat(D, "n1 -> r n1", r=copies)if merge:D = D.flatten(0, 1)D = nn.Parameter(D)  # Keep in fp32D._no_weight_decay = Truereturn Ddef forward_corev2(self, x: torch.Tensor, channel_first=False, SelectiveScan=SelectiveScanCore,cross_selective_scan=cross_selective_scan, force_fp32=None):force_fp32 = (self.training and (not self.disable_force32)) if force_fp32 is None else force_fp32if not channel_first:x = x.permute(0, 3, 1, 2).contiguous()if self.ssm_low_rank:x = self.in_rank(x)x = cross_selective_scan(x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,self.A_logs, self.Ds,out_norm=getattr(self, "out_norm", None),out_norm_shape=getattr(self, "out_norm_shape", "v0"),delta_softplus=True, force_fp32=force_fp32,SelectiveScan=SelectiveScan, ssoflex=self.training,  # output fp32)if self.ssm_low_rank:x = self.out_rank(x)return xdef forward(self, x: torch.Tensor, **kwargs):x = self.in_proj(x)if not self.disable_z:x, z = x.chunk(2, dim=1)  # (b, d, h, w)if not self.disable_z_act:z1 = self.act(z)if self.d_conv > 0:x = self.conv2d(x)  # (b, d, h, w)x = self.act(x)y = self.forward_core(x, channel_first=(self.d_conv > 1))y = y.permute(0, 3, 1, 2).contiguous()if not self.disable_z:y = y * z1out = self.dropout(self.out_proj(y))return out

Local Spatial Block

Mamba 体系结构已被证明在捕获远程地面依赖性方面是有效的。然而,在处理涉及复杂尺度变化的任务时,它在提取局部特征方面面临一定的挑战。在图4(c)中,本文提出了LocalSpatial Block来增强对局部特征的捕获。具体而言,对于给定的输入特征F^{l-1},它首先进行深度可分离卷积,该卷积在不混合信道信息的情况下单独地对每个输入信道进行操作。有效提取输入特征图的局部空间信息,同时降低计算成本和参数数量,然后进行批量归一化,在减少过拟合的同时提供一定程度的正则化效果,得到的中间状态F^{l-1}定义为:

中间状态F^{l-1}通过1×1卷积混合通道信息,并通过激活函数更好地保持信息的分布,使模型能够学习更复杂的特征表示,这些特征表示能够从输入特征图中提取丰富的多尺度上下文信息。在LSBlock中,激活函数使用非线性GELU来改变特征的通道数量,而不改变空间维度,从而增强特征表示。最后,通过残差拼接将原始输入与处理后的特征融合。使模型能够理解和集成图像中不同维度的特征,从而提高对比例变化的鲁棒性。 

class LSBlock(nn.Module):def __init__(self, in_features, hidden_features=None, act_layer=nn.GELU, drop=0):super().__init__()self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=3, padding=3 // 2, groups=hidden_features)self.norm = nn.BatchNorm2d(hidden_features)self.fc2 = nn.Conv2d(hidden_features, hidden_features, kernel_size=1, padding=0)self.act = act_layer()self.fc3 = nn.Conv2d(hidden_features, in_features, kernel_size=1, padding=0)self.drop = nn.Dropout(drop)def forward(self, x):input = xx = self.fc1(x)x = self.norm(x)x = self.fc2(x)x = self.act(x)x = self.fc3(x)x = input + self.drop(x)return x

通过Local Spatial Block来增强局部特征的提取。对于给定的输入特征,它首先进行深度可分卷积,该卷积在不混合通道信息的情况下分别对每个输入通道进行操作,有效提取输入特征图的局部空间信息,同时降低计算成本和参数数量,然后进行批处理归一化,在减少过拟合的同时提供一定程度的正则化效果。 


ResGated Block

最初的MLP仍然是最广泛采用的,VMamba架构中的MLP也遵循Transformer设计,对输入序列进行非线性变换,以增强模型的表达能力。最近的研究表明,门控MLP在自然语言处理中表现出强大的性能,我们发现门控机制对视觉具有同样的潜力。在图4(d)中,本文提出ResGated Block的简单设计旨在以低计算成本提高模型的性能,RG Block从输入X^{l-2}创建两个分支X^{l-1}_{1}X^{l-1}_{2},并在每个分支上以1×1卷积的形式实现全连接层。

X^{l-1}_{2}的分支上使用深度分离卷积作为位置编码模块,并且在训练过程中通过残差级联更有效地反映梯度,这具有更低的计算成本,并且通过保留和使用图像的空间结构信息来显著提高性能。RG块采用非线性GeLU作为激活函数来控制每个级别的信息流,然后通过元素乘法与X^{l-1}_{2}的一个分支合并,然后通过1x1卷积与全局特征进行细化以混合信道信息,最后通过残差级联与原始输入X^{l-2}与隐藏层中的特征求和。RG Block可以捕获更多的全局特征,同时只带来轻微的计算成本增加,由此产生的输出特征X^{l}定义为: 

class RGBlock(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featureshidden_features = int(2 * hidden_features / 3)self.fc1 = nn.Conv2d(in_features, hidden_features * 2, kernel_size=1)self.dwconv = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1, bias=True,groups=hidden_features)self.act = act_layer()self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)self.drop = nn.Dropout(drop)def forward(self, x):x, v = self.fc1(x).chunk(2, dim=1)x = self.act(self.dwconv(x) + x) * vx = self.drop(x)x = self.fc2(x)x = self.drop(x)return x

 

使用深度可分卷积作为主流分支上的位置编码模块,在训练过程中通过残差连接的方式对梯度进行更有效的回传,计算成本更低,并且通过保留和利用图像的空间结构信息,显著提高性能。

实际代码与论文中图片有所出入:

1、代码中是输入经过卷积Conv2d后被拆分,而论文模型结构中是拆分后分别经过两个卷积Conv2d

2、代码中最后输出经过self.drop后,并未按照论文模型结构中再一次采用残差方式进行连接


Experiments

在VOC0712数据集上进行Mamba YOLO以进行消融实验,测试模型为Mamba YOLO-T。作者的结果表2显示,线索合并为状态空间模型(SSM)保留了更多的视觉线索,也为ODSS块结构确实是最优的断言提供了证据。 

RGBlock通过逐像素获取全局相关性和全局特征来捕获逐像素的局部相关性。关于RG块设计的细节,作者还考虑了多层感知基础之上的三种变体:1)卷积MLP,它将DW-Conv添加到原始MLP;2) Res卷积MLP,其以残差级联方式将DW-Conv添加到原始MLP;3) 门控MLP,一种在门控机制下设计的MLP变体。图5说明了这些变体,表3显示了原始MLP、RG块和VOC0712数据集中每个变体的性能,以验证作者使用测试模型Mamba YOLO-T对MLP分析的有效性。作者观察到,卷积的引入并不能有效提高性能,其中在图5(d)门控MLP的变体中,其输出由两个元素乘法的线性投影组成,其中一个由残差连接的DWConv和门控激活函数组成,这实际上使模型能够通过分层结构函数传播重要特征。该实验表明,在处理复杂图像任务时,引入的卷积性能的提高与门控聚合机制非常相关,前提是它们适用于残差连通性的情况。

为了评估作者提出的基于ssm的Mamba YOLO架构的优越性和良好的可扩展性,作者将其应用于除目标检测领域外的实例分割任务。作者采用Mamba YOLO-T之上的v8分割头,并在COCOSeg数据集上对其进行训练和测试,通过Bbox AP和Mask AP等指标评估模型性能。Mamba YOLO-T-seg在每种尺寸上都显著优于YOLOv5和YOLOv8的分割模型。RTMDet基于包含深度卷积大内核的基本构建块,在动态标签分配过程中引入软标签来计算匹配成本,并在几个视觉任务中表现出出色的性能,Mamba YOLO-T-seg与Tiny相比,在Mask mAP上仍保持2.3的优势。结果如表4和图8所示。 


Conclusion

在本文中,作者重新分析了CNN和Transformer架构在目标检测领域的优缺点,并指出了它们融合的局限性。基于此,作者提出了一种基于状态空间模型架构设计并由YOLO扩展的检测器,作者重新分析了传统MLP的局限性,并提出了RG块,其门控机制和深度卷积残差连通性被设计为使模型能够在分层结构中传播重要特征。此外,为了解决Mamba架构在捕获局部依赖性方面的局限性,LSBlock增强了捕获局部特征的能力,并将它们与原始输入融合,以增强特征的表示,这显著提高了模型的检测能力。作者的目标是建立一个新的YOLO基线,前提是Mamba YOLO具有高度竞争力。

关键字:Mamba-YOLO : 基于SSM的YOLO目标检测算法(附代码)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: