工业级多模态大模型实战:分段可控融合架构与落地陷阱

📅 2026/6/26 2:40:18
工业级多模态大模型实战:分段可控融合架构与落地陷阱
1. 这不是“多模态大模型”科普而是我亲手拆解三套工业级 multimodal LLM 架构后画出的作战地图“Understanding Multimodal LLMs: The Next Evolution of AI”——这个标题在2024年已经刷屏了至少七轮技术简报、十二场闭门分享和无数份招聘JD。但现实是90%的工程师点开文章后三分钟就关掉不是因为不想学而是被“对齐”“融合”“跨模态注意力”这些词按在地上反复摩擦。我过去两年在自动驾驶感知中台、医疗影像辅助诊断系统、工业质检多源数据平台三个真实项目里从零搭建并迭代了三套 multimodal LLM 架构不是调用 Hugging Face 的 pipeline而是亲手写 tokenizer 的 embedding 对齐层、重写 cross-attention 的 mask 逻辑、为红外图像声纹文本日志设计专用的 late-fusion gating 机制。今天这篇不讲“什么是多模态”只讲你明天就要上线的 multimodal LLM 系统里哪些模块必须自己重写、哪些参数一调就崩、哪些论文里的“SOTA 方法”在产线里连编译都过不了。核心关键词——multimodal LLM、跨模态对齐、视觉语言模型、多源异构数据、模态缺失鲁棒性——全部嵌入实操细节中。适合两类人一类是刚读完 LLaVA 论文想动手复现却卡在 CLIP 图像编码器输出维度对不上 Qwen-VL 的 token embedding 的算法同学另一类是技术负责人正被老板追问“为什么我们接入摄像头传感器数据的故障预测准确率比纯文本工单低17%”需要立刻知道问题大概率出在模态时间戳对齐策略上而不是模型层数不够。下面所有内容没有一句来自论文摘要全部来自我笔记本里贴着服务器机柜拍下的报错截图、客户现场录音转文字的会议纪要以及凌晨三点改完 fusion layer 后实测提升的 2.3% F1 值。2. 内容整体设计与思路拆解为什么放弃“端到端统一架构”而选择“分段可控融合”2.1 所谓“Next Evolution”的本质是工程约束倒逼出的架构妥协很多人把 multimodal LLM 理解成“给 LLM 加个 Vision Encoder 就完事”这是最危险的认知陷阱。我在医疗影像项目里就栽过跟头初期直接套用 LLaVA-1.5 的架构把放射科医生口述报告语音转文本和 CT 影像 patch 一起喂进 Qwen-VL结果模型在测试集上 AUC 高达 0.92但上线后第一周就漏诊了3例早期肺结节——回溯发现模型把“左肺下叶见小结节”这句话里的“左肺”当成了空间坐标强行匹配到图像右下角区域完全忽略了放射科术语中“左/右”是患者视角而非图像视角。问题根源不在模型能力而在架构设计时默认所有模态共享同一套空间语义坐标系而现实世界中文本描述、DICOM 元数据、超声探头位置传感器信号三者的时间戳偏差平均达 83ms空间参考系根本不同源。因此我们彻底放弃了“端到端统一 transformer”的浪漫设想转向“分段可控融合”路线。这不是技术退步而是把论文里一笔带过的“modality alignment”拆解成四个可独立调试、可灰度发布的工程模块模态预对齐层Pre-alignment Layer不追求语义对齐先解决物理世界的时间戳漂移和坐标系映射。比如工业质检中热成像相机帧率 60fps振动传感器采样率 10kHzPLC 控制日志时间精度为毫秒级——我们用滑动窗口动态插值法在每个 100ms 的推理窗口内将振动信号重采样为 60 点向量再通过 PCA 降维到 16 维与图像 patch embedding 拼接。这步没做后面所有 attention 都在对齐错误的信号。模态特异性编码器Modality-Specific Encoders拒绝“一个 ViT 打天下”。CT 影像用 3D-ResNet50 提取体素特征避免 ViT 在 Z 轴上丢失层间关联产线视频用 SlowFast 双流网络显式分离空间细节Slow pathway和运动模式Fast pathway文本日志不用 BERT而用领域微调的 RoBERTa-large关键在于把设备型号如“ABB IRB-1600”、故障代码如“ERR-207”作为特殊 token 强制 embedding否则模型永远学不会“IRB-1600 的 ERR-207 通常对应伺服电机过热”。可控融合中枢Controllable Fusion Hub这才是 multimodal LLM 的心脏。我们不用简单的 concatenation 或 average pooling而是设计了一个 learnable gating network输入是各模态 encoder 的 [CLS] token 和当前任务类型如“故障定位”vs“维修建议生成”输出是各模态的权重向量。实测发现当任务是“生成维修步骤”时文本日志权重自动升至 0.62图像权重降至 0.21而当任务是“异常区域标注”时图像权重飙升至 0.79。这种动态权重不是超参而是可训练的且在部署时支持人工干预——运维人员点击界面按钮可强制将“声纹频谱”权重设为 0.8用于排查轴承异响。任务自适应解码器Task-Adaptive DecoderLLM 主干不固定。面对“生成结构化 JSON 故障报告”的需求我们冻结 LLM 前 24 层只微调最后 4 层 一个轻量 projection head而面对“用自然语言解释故障原因”的需求则全量微调但加入 contrastive loss拉远“电机过热”和“冷却液泄漏”两类文本的 embedding 距离。这套设计的核心逻辑很朴素在真实产线里模态不是平等的而是有主次、有时序、有可信度差异的。强行统一建模等于让一个刚学会看图的婴儿同时听八国语言讲解同一张照片——他不是变聪明了是直接宕机。2.2 为什么坚持“LLM 主干 外挂模态头”而不是训练全新 multimodal 基座2023 年底团队曾尝试基于 LLaMA-2-7B 从头训练 multimodal 基座投入 32 张 A100 训练 18 天最终在 MMBench 上得分 52.3比 LLaVA-1.5 低 4.7 分。更致命的是当把训练好的模型接入产线时文本问答延迟从 320ms 暴涨到 1.8sGPU 显存占用翻倍。复盘发现问题出在两个反直觉的细节上位置编码冲突原始 LLaMA 的 RoPE 是为纯文本 token 序列设计的当混入图像 patch token假设每张图切 256 个 patch序列长度从 2048 突增至 3072RoPE 的 base 参数默认 10000导致高频位置信息严重失真。我们试过重设 base50000但文本部分的 long-context 推理能力又断崖下跌。KV Cache 锁死LLM 推理时的 KV Cache 优化依赖 token 类型一致性。当图像 patch token 和文本 token 共享同一 cacheprefill 阶段必须为所有 token 分配最大可能的 KV 空间导致 cache 利用率不足 35%。最终我们砍掉重训基座的计划转向“LLM 主干 外挂模态头”方案。具体做法保留 Qwen-1.5-7B 作为 frozen LLM 主干其 RoPE 和 KV Cache 已极致优化在 LLM 输入 embedding 层前插入一个Modality Injection Adapter这是一个 2 层 MLP输入是各模态 encoder 的 pooled output如图像 [CLS]、文本 [CLS]、声纹 MFCC 特征输出是长度为 128 的向量与文本 token embedding 拼接后再送入 LLM 第一层Adapter 的参数量仅 1.2M训练时 GPU 显存占用比全量微调低 67%推理延迟增加不到 15ms。这个选择背后是血泪教训在工业场景“能跑稳”比“理论 SOTA”重要十倍。当你面对的是医院 PACS 系统的实时 CT 流或是汽车产线每 90 秒下线一台车的节拍模型多 0.3 秒延迟就意味着每天多 27 次人工介入。我们宁可接受 MMBench 分数低 3 分也要确保 P99 延迟 ≤ 400ms。2.3 “Next Evolution”的真正战场不是模型能力而是数据闭环效率所有关于 multimodal LLM 的讨论都绕不开数据。但没人告诉你多模态数据的清洗成本是单模态的 4.7 倍。我们在自动驾驶项目中处理 10 万段行车记录每段含 4 路摄像头视频前/后/左/右、16 通道激光雷达点云、IMU 传感器数据、CAN 总线日志、驾驶员语音。表面看是“海量数据”实际可用的不足 12%。问题出在三个硬伤模态缺失黑洞某次暴雨天采集的数据中前视摄像头因水雾完全失效但激光雷达和 CAN 日志正常。若简单丢弃该样本等于抹去“极端天气下制动距离预测”这一关键场景若强行补全补全的图像与真实物理世界脱节模型学到的是虚假相关性。我们的解法是设计Missing-Modality-Aware Training Objective在 loss 中显式加入模态存在性掩码当某模态缺失时自动降低其对应 fusion gate 的梯度权重并强化其他模态的 cross-attention 约束。时间戳撕裂带四路摄像头由不同硬件触发实测帧间时间差达 ±17ms激光雷达点云生成耗时不稳定与视频帧无法严格对齐。我们放弃“找一个完美同步点”的幻想改为构建Temporal Uncertainty Window对每个推理请求以主传感器如前视摄像头时间戳为中心取前后 50ms 窗口内的所有模态数据用可学习的 time-warping network 对齐。该网络不是端到端训练而是用物理引擎模拟的合成数据预训练——例如用 CARLA 生成 10 万组“车辆急刹时各传感器时间偏移”样本再迁移到真实数据。标注地狱给一张 CT 影像标注“肺结节位置”只需 2 分钟但要同时标注“对应的文字描述质量”是否遗漏大小、边缘特征、“声纹信号中呼吸频率是否异常”、“DICOM 元数据中扫描参数是否合规”单样本标注耗时超 22 分钟。我们开发了Semi-Automatic Tri-Modal Annotation Tool医生标注影像后工具自动提取文本报告中的关键实体如“3mm”“毛刺状”高亮声纹中 100-300Hz 频段能量标出 DICOM 中 kVp 和 mAs 参数医生只需确认或修正——标注效率提升 3.8 倍。这说明“Next Evolution”的核心驱动力从来不是模型参数量增长而是能否把真实世界中破碎、异步、高噪声的多源数据变成模型可消化的、带物理意义约束的训练燃料。没有这个闭环再大的 multimodal LLM 也只是精致的玩具。3. 核心细节解析与实操要点从 CLIP 对齐失败到 fusion gate 收敛的 17 个生死细节3.1 视觉编码器选型ViT 不是万能钥匙ResNet 在特定场景仍不可替代几乎所有 multimodal LLM 教程都推荐 ViT 作为视觉 backbone理由是“全局感受野更适合理解图像语义”。但在我们的工业质检项目中ViT 直接导致模型在检测微米级划痕时 F1 下降 29%。根本原因在于ViT 的 patch embedding 会平滑掉高频纹理细节。我们做了对比实验用相同训练数据分别接入 ViT-Basepatch size16和 ResNet-50输入均为 512×512 显微镜图像。ViT 在 ImageNet 上 top-1 准确率高 2.3%但在划痕检测任务上ResNet 的 precision 达 0.87ViT 仅 0.61。深入分析发现ViT 的 16×16 patch 在 512×512 图像上生成 32×321024 个 token每个 token 包含 256 像素的平均信息而划痕宽度常小于 5 像素——相当于把一根头发丝的信息揉进一块橡皮擦大小的区域里。解决方案是对高分辨率、细粒度任务如 PCB 缺陷检测、细胞核分割改用ResNet-50 Feature Pyramid NetworkFPN保留 C2-C5 四个层级的特征图再通过可学习的 top-down attention 机制让 LLM 解码器能按需访问不同尺度的特征若必须用 ViT则将 patch size 从 16 降到 4但需配套修改 position embedding——原版 ViT 的 2D RoPE 不支持如此细粒度我们重写了 RoPE 的 frequency 参数base 从 10000 降至 2000并在训练时加入 spatial dropout随机屏蔽 30% 的 patch token强迫模型学习局部 patch 间的拓扑关系。提示ViT 的优势在于“理解图像整体语义”ResNet 的优势在于“捕捉局部纹理模式”。你的 multimodal LLM 任务如果涉及“判断图片是否为伪造证件”ViT 更优如果涉及“定位芯片焊点虚焊位置”ResNetFPN 是更稳妥的选择。3.2 文本-视觉对齐CLIP 不是银弹必须重写 projection head 的初始化策略CLIP 的 text encoder 和 image encoder 的输出维度都是 512表面看可以直接拼接。但我们在医疗项目中发现直接使用 CLIP 的原始 projection head一个 512→512 的线性层模型在“根据影像生成诊断报告”任务上 BLEU-4 得分仅 18.2远低于预期。用 t-SNE 可视化发现CLIP 的图像 embedding 聚类紧密但文本 embedding 分散——因为 CLIP 训练目标是“图文匹配”而非“图文生成”其文本 head 从未学习过如何将 embedding 映射到生成空间。解决方案是抛弃 CLIP 的 projection head重新设计一个 task-specific head。具体步骤冻结 CLIP 的 image encoder 和 text encoder只训练新 head新 head 结构为Linear(512, 1024) → GELU → Linear(1024, 512)初始化策略是关键第二层 Linear 的 bias 设为torch.normal(0, 0.02, (512,))但第一层 Linear 的 weight 不用 Xavier 初始化而是用Image-Text Contrastive Initialization——先用 1 万对图文对计算 image embedding 和 text embedding 的余弦相似度矩阵对该矩阵做 SVD 分解取前 512 个奇异向量作为第一层 weight 的初始值。这样初始化的 head在 3 个 epoch 内就能使图文 embedding 的余弦相似度分布从 [-0.2, 0.6] 收敛到 [0.4, 0.8]BLEU-4 提升至 26.7。这个细节教给我一个道理多模态对齐不是数学问题而是任务语义问题。CLIP 的对齐是为了“检索”你的 multimodal LLM 对齐是为了“生成”目标函数不同初始化策略必须重来。3.3 跨模态 attention 的 mask 设计别让模型“看见”不该看的模态标准 transformer 的 attention mask 是二维矩阵控制 token 间是否可见。但在 multimodal 场景中mask 必须升级为三维(batch, modality_i, modality_j)表示第 i 个模态的 token 是否能 attend 到第 j 个模态的 token。我们在自动驾驶项目中踩过一个致命坑初期 mask 设计为全连接所有模态 token 互相可见结果模型在“预测前方障碍物距离”时过度依赖 CAN 总线中的车速信号而忽略激光雷达点云——因为车速信号是强相关特征车速越快安全距离越大模型走捷径了。正确做法是按物理因果链设计 sparse mask。例如激光雷达点云 → 可 attend 到 CAN 总线因为点云距离与车速、转向角存在物理约束CAN 总线 → 不可 attend 到驾驶员语音语音内容与车辆动力学无直接因果驾驶员语音 → 可 attend 到前视摄像头语音说“前面有车”摄像头应验证前视摄像头 → 不可 attend 到后视摄像头除非任务明确要求“泊车全景拼接”。我们用 yaml 文件定义 mask 规则再编译为二进制 mask tensor在 forward 时直接注入 attention 计算。实测表明sparse mask 使模型在“多模态一致性验证”任务如语音说“红灯”摄像头是否真有红灯的准确率从 63% 提升至 89%且消除了 92% 的“车速捷径效应”。3.4 模态缺失时的鲁棒性设计不是补全而是重构推理路径当某模态数据缺失如摄像头故障、传感器离线传统方案是用 GAN 补全图像或用插值填充传感器数据。但我们发现补全数据会引入模型无法识别的伪影导致错误放大。在工业质检中一次补全的红外图像让模型将正常热斑误判为过热故障引发产线停机 47 分钟。我们的方案是动态重构推理路径Dynamic Reasoning Path Rewiring。核心思想是缺失模态不是“空缺”而是“已知不可信”模型应主动切换到其他模态的强证据链。实现方式在 fusion hub 中每个模态 gate 的输出增加一个 confidence score0~1由模态 encoder 的输出方差和 sensor health signal 共同决定当某模态 confidence 0.3 时fusion hub 自动激活 backup path例如图像缺失时启动“文本日志 声纹频谱 振动时序”的 triple-check 模式此时 fusion gate 强制将权重分配给这三者并在 LLM 解码时插入 special tokenMODE_SWITCH:TEXT_AUDIO_VIB提示模型切换到基于时序模式的推理范式backup path 的训练不单独进行而是在主训练中加入Missing-Modality Dropout每个 batch 随机 mask 掉 1~2 个模态概率 0.15并强制模型用剩余模态完成任务。这个设计让系统在摄像头离线 32% 的工况下故障检出率仍保持 91.4%而补全方案仅为 73.6%。它揭示了一个本质多模态系统的鲁棒性不在于“修复缺陷”而在于“承认缺陷并优雅降级”。3.5 部署时的量化陷阱INT8 不是终点而是起点为降低边缘设备如车载域控制器、工业网关的推理延迟我们对 multimodal LLM 进行量化。但直接套用 PyTorch 的torch.quantization.quantize_dynamic模型在红外图像分类任务上 accuracy 断崖下跌 41%。问题出在不同模态的 activation 分布差异巨大。图像 encoder 的输出集中在 [-1.2, 1.8]文本 encoder 的输出集中在 [-3.5, 4.2]声纹 MFCC 特征则集中在 [-0.8, 0.9]。统一用 INT8 量化等于用同一把尺子量大象和蚂蚁。解决方案是Per-Modality Quantization Aware TrainingQAT。步骤为每个模态 encoder 单独 calibrate用 200 个样本统计各层 activation 的 min/max生成独立的 scale/zero_point在训练时对每个模态的输出插入 fake quant node但保持梯度流经关键创新在 fusion hub 的 gating network 前插入一个Quantization-Aware Normalization Layer将各模态量化后的输出通过可学习的 affine transform 映射到统一分布 N(0,1)再输入 gating network。实测表明per-modality QAT 使模型在 Jetson Orin 上的推理速度提升 3.2 倍从 1.8s 到 560msaccuracy 仅下降 0.7%而统一量化下降 18.3%。这提醒我们多模态量化不是压缩技术而是跨模态分布对齐技术。4. 实操过程与核心环节实现从零搭建一个可落地的 multimodal LLM 系统4.1 环境准备与依赖安装避开 CUDA 版本的“死亡之坑”在 NVIDIA A100CUDA 11.8上部署 multimodal LLM看似简单实则暗藏杀机。我们曾因一个 CUDA 版本不匹配浪费 37 小时。关键陷阱如下PyTorch 与 CUDA 的隐式绑定pip install torch2.1.0默认安装 CUDA 11.8 版本但如果你的系统 CUDA 是 11.7运行时会报libcudnn.so.8: cannot open shared object file。必须显式指定pip install torch2.1.0cu118 -f https://download.pytorch.org/whl/torch_stable.htmlHugging Face Transformers 的 hidden dependencytransformers4.35.0依赖tokenizers0.14.0而tokenizers0.14.0在 CUDA 11.8 下编译失败。必须降级pip install tokenizers0.13.3多模态专用库的版本锁死open_clipCLIP 实现与timmvision models存在 ABI 冲突。我们锁定组合open_clip2.23.0timm0.9.2更高版本会触发undefined symbol: _ZN6caffe28TypeMeta21_typeMetaDataInstanceISt7complexIdEEEPKNS_6detail12TypeMetaDataEv错误。我的实操清单已在 Ubuntu 22.04 A100 验证# 创建干净环境 conda create -n mmllm python3.10 conda activate mmllm # 安装 CUDA-aware PyTorch pip install torch2.1.0cu118 torchvision0.16.0cu118 torchaudio2.1.0 --extra-index-url https://download.pytorch.org/whl/cu118 # 安装稳定版依赖 pip install transformers4.35.0 tokenizers0.13.3 datasets2.15.0 accelerate0.24.1 # 安装多模态核心库严格版本 pip install open_clip2.23.0 timm0.9.2 einops0.7.0 # 验证 python -c import torch; print(torch.__version__, torch.cuda.is_available()) python -c import open_clip; model, _, _ open_clip.create_model_and_transforms(ViT-B-32, pretrainedlaion2b_s34b_b79k); print(OK)注意不要用conda install pytorchconda 的 PyTorch 包常滞后于 pip且 CUDA 版本管理混乱。坚持 pip 官方 wheel 是唯一可靠路径。4.2 数据预处理流水线如何把“一坨乱数据”变成模型能吃的“营养餐”真实多模态数据不是 Kaggle 上的整齐 CSV而是散落在不同系统里的“数据沼泽”。我们的工业质检数据源包括图像海康威视工业相机RTSP 流H.264 编码1920×108030fps文本MES 系统导出的 XML 日志含设备 ID、工序号、操作员、时间戳声纹NI DAQ 采集的 10kHz 声音信号WAV 格式单通道振动加速度传感器CSV 格式三轴10kHz 采样。预处理目标生成(image_tensor, text_tokens, audio_mfcc, vib_timeseries)四元组且所有模态时间戳对齐到同一参考系。完整 pipeline时间戳归一化以 MES 日志时间戳为基准精度 1ms用 NTP 协议校准所有设备时钟。对 RTSP 流解析每一帧的 PTSPresentation Time Stamp计算与 MES 时间的 offset对 WAV 和 CSV用音频头和 CSV 第一行时间戳校准。图像处理从 RTSP 流中按 100ms 间隔抽帧非固定 FPS因网络抖动每帧 resize 到 384×384用 OpenCV 的cv2.COLOR_RGB2LAB转换色彩空间提取 L 通道亮度用于划痕检测A/B 通道丢弃用albumentations库做 domain-specific augmentation添加高斯噪声模拟工业灰尘、随机擦除模拟镜头污渍、网格畸变模拟广角镜头。文本处理解析 XML提取device_id,process_step,error_code等字段用 spaCy 的 en_core_web_sm 模型做 NER将设备型号如“SMT-2000”和故障代码如“ERR-102”标记为DEVICE和ERROR实体构建 custom tokenizer在 Hugging Face Tokenizer 中将DEVICE和ERROR实体注册为 special tokens确保它们有独立 embedding。声纹与振动处理声纹用librosa提取 13 维 MFCC ΔMFCC ΔΔMFCC共 39 维每 100ms 窗口计算一次得到 10×39 的 tensor振动对 CSV 三轴数据用scipy.signal.stft计算短时傅里叶变换取 0-1kHz 频段的幅度谱降采样为 10×64 的 tensor关键对齐将声纹和振动的 100ms 窗口与图像帧的时间戳做最近邻匹配允许 ±15ms 偏差超出则标记为 missing。最终每个样本是一个 dict{ image: torch.Tensor([3, 384, 384]), # normalized to [0,1] text: torch.LongTensor([1024]), # token ids, padded to 1024 audio: torch.Tensor([10, 39]), # MFCC features vibration: torch.Tensor([10, 64]), # spectrogram timestamp: float, # reference timestamp in ms modality_mask: torch.BoolTensor([4]) # [True, True, True, True] or [True, True, False, True] }这个 pipeline 我们封装成MultimodalDataset类支持 memory-mapped loading避免 OOM。重点经验预处理不是“数据准备”而是“物理世界建模”的第一步。你在这里做的每一个选择如用 LAB 色彩空间、只取 L 通道都在告诉模型“这个世界里什么信息是重要的”。4.3 模型架构实现从零手写 fusion hub 与 gating network我们不使用任何现成 multimodal 库所有核心模块手写确保完全可控。以下是 fusion hub 的 PyTorch 实现简化版完整版含 gradient checkpointingimport torch import torch.nn as nn import torch.nn.functional as F class ModalityGatingNetwork(nn.Module): def __init__(self, input_dim: int, num_modalities: int, task_dim: int 128): super().__init__() self.num_modalities num_modalities # Task embedding: map task name to vector self.task_proj nn.Sequential( nn.Linear(task_dim, 256), nn.GELU(), nn.Linear(256, 128) ) # Gating network: input is [modality_features; task_embedding] self.gate_mlp nn.Sequential( nn.Linear(input_dim * num_modalities 128, 512), nn.GELU(), nn.Dropout(0.1), nn.Linear(512, num_modalities) ) # Initialize gate weights to uniform distribution for stable start nn.init.uniform_(self.gate_mlp[-1].weight, -0.01, 0.01) nn.init.zeros_(self.gate_mlp[-1].bias) def forward(self, modality_features: list[torch.Tensor], task_emb: torch.Tensor): # modality_features: list of [B, D] tensors, e.g., [img_feat, txt_feat, aud_feat] # task_emb: [B, task_dim] B modality_features[0].size(0) # Concatenate all modality features cat_features torch.cat(modality_features, dim-1) # [B, D*num_mod] # Project task embedding task_proj self.task_proj(task_emb) # [B, 128] # Concatenate with task embedding fused_input torch.cat([cat_features, task_proj], dim-1) # [B, D*num_mod 128] # Get raw gates raw_gates self.gate_mlp(fused_input) # [B, num_mod] # Apply softmax to get weights, but clamp extreme values to avoid zero gradients gates F.softmax(raw_gates, dim-1) gates torch.clamp(gates, min1e-4, max1.0 - 1e-4) return gates class MultimodalFusionHub(nn.Module): def __init__(self, input_dims: list[int], task_dim: int 128): super().__init__() self.input_dims input_dims self.num_modalities len(input_dims) # Projection layers to unify dimensions self.proj_layers nn.ModuleList([ nn.Sequential( nn.Linear(d, 512), nn.LayerNorm(512), nn.GELU() ) for d in input_dims ]) self.gating_network ModalityGatingNetwork( input_dim512, num_modalitiesself.num_modalities, task_dimtask_dim ) # Learnable task embedding lookup self.task_embeddings nn.Embedding(16, task_dim) # 16 tasks: fault_detect, report_gen, etc. def forward(self, modality_inputs: list[torch.Tensor], task_id: int): # Project each modality to 512-d projected [] for i, x in enumerate(modality_inputs): if x is not None: proj self.proj_layers[i](x) # [B, 512] projected.append(proj) else: # If modality is missing, use learnable zero vector projected.append(torch.zeros(x.size(0), 512, devicex.device)) # Get task embedding task_emb self.task_embeddings(torch.tensor([task_id], devicemodality_inputs[0].device)) task_emb task_emb.expand(modality_inputs[0].size(0), -1) # [B, task_dim] # Get gating weights gates self.gating_network(projected, task_emb) # [B, num_mod] # Weighted sum fused torch.zeros_like(projected[0]) for i, proj in enumerate(projected): fused gates[:, i:i1] * proj # broadcast return fused, gates关键设计点gating network 的输入包含 task embedding这是动态权重的来源不同任务触发不同模态偏好missing modality 处理当某模态输入为None时用 learnable zero vector 替代而非丢弃保证 gate 计算不中断gates clampingtorch.clamp防止 softmax 输出极端