A New Approach to Remote Sensing Image Classification: Reproducing a Spectral-Spatial Attention Network in PyTorch, Step by Step

📅 2026/6/30 14:16:31
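# Remote Sensing Image Classification in Practice: A Complete Guide to Implementing a Spectral-Spatial Attention Network in PyTorch

Hyperspectral remote sensing image classification has long been an active research topic where computer vision and remote sensing meet. Traditional methods often fail to fully exploit the rich spectral and spatial information in a hyperspectral data cube, and the rise of deep learning offers a new way to tackle this problem. This article focuses on a PyTorch implementation of the Spectral-Spatial Attention Network and walks through the code to cover the key engineering points of putting this technique into practice.

## 1. Environment Setup and Data Loading

### 1.1 Development Environment

Python 3.8 with PyTorch 1.10 is recommended. Install the key dependencies with:

```bash
pip install torch torchvision torchaudio
pip install numpy scipy scikit-learn matplotlib
pip install spectral  # library dedicated to hyperspectral data processing
```

For GPU acceleration, CUDA 11.3 or later is recommended. Verify the environment with:

```python
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
```

### 1.2 Dataset Processing

We use the Indian Pines and Pavia University datasets as examples. First download the raw data (.mat format), then preprocess it:

```python
import scipy.io as sio
import numpy as np

def load_hsi_data(data_path, img_key="img", gt_key="gt"):
    """Load a hyperspectral cube and its ground-truth map from a .mat file.

    Variable names inside .mat files differ between datasets, and many public
    releases store the cube and the labels in two separate files; list them
    with sio.whosmat(data_path) and pass the matching keys.
    """
    data = sio.loadmat(data_path)
    img = data[img_key]  # numpy array of shape (H, W, C)
    gt = data[gt_key]    # ground-truth label map of shape (H, W); 0 = unlabeled background
    return img, gt

# Per-band standardization
def standardize(data):
    data = data.astype(np.float32)
    mean = np.mean(data, axis=(0, 1))
    std = np.std(data, axis=(0, 1))
    return (data - mean) / (std + 1e-8)  # epsilon avoids division by zero for constant bands
```

Note: hyperspectral data usually needs dimensionality reduction. PCA is the common choice, and nonlinear methods such as UMAP are also worth trying. t-SNE is better suited to visualization, because it learns no mapping that can be applied to new pixels.

### 1.3 Data Augmentation

Hyperspectral datasets usually have few labeled samples, so data augmentation is essential. The geometric transforms below apply to spatial patch tensors of shape (C, H, W), since torchvision's tensor transforms operate on the last two dimensions:

```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(30),
    # The spectral dimension can also be randomly perturbed
])
```

## 2. Network Architecture

### 2.1 Spectral Attention Branch

The spectral branch uses a bidirectional GRU followed by an attention mechanism. The GRU is built with `batch_first=True`, so the spectral input is laid out as (batch, bands, 1), which is what a `DataLoader` yields and what the ONNX export in Section 5.2 uses; the attention softmax therefore runs over `dim=1`, the band axis:

```python
import torch
import torch.nn as nn

class SpectralAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attention = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, x):
        # x shape: (batch, seq_len, input_dim); seq_len = number of spectral bands
        outputs, _ = self.gru(x)                             # (batch, seq_len, 2*hidden_dim)
        energy = self.attention(outputs)                     # (batch, seq_len, 1)
        weights = torch.softmax(energy.squeeze(-1), dim=1)   # attention weights over bands
        return torch.sum(outputs * weights.unsqueeze(-1), dim=1)  # (batch, 2*hidden_dim)
```

### 2.2 Spatial Attention Branch

The spatial branch uses a CNN with a spatial attention mask:

```python
class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.attention = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # x shape: (batch, in_channels, H, W)
        features = self.conv1(x)              # (batch, 64, H, W); padding=1 keeps H and W
        attention = self.attention(features)  # (batch, 1, H, W), values in (0, 1)
        return features * attention           # mask is broadcast over all 64 channels
```

### 2.3 Fusing the Two Branches

The features of the two branches are concatenated and passed to a classifier. `FusionNetwork` also exposes `extract_features`, which returns the fused feature vector before the classifier; the t-SNE visualization in Section 4.3 relies on it:

```python
class FusionNetwork(nn.Module):
    # e.g. FusionNetwork(spectral_dim=1, spatial_dim=3, num_classes=16) for Indian Pines
    def __init__(self, spectral_dim, spatial_dim, num_classes):
        super().__init__()
        # spectral_dim: features per GRU step (1 when each band is one step); output 2*128 = 256
        self.spectral_branch = SpectralAttention(spectral_dim, 128)
        self.spatial_branch = SpatialAttention(spatial_dim)  # spatial_dim: channels after PCA, e.g. 3
        self.classifier = nn.Sequential(
            nn.Linear(256 + 64 * 27 * 27, 512),  # assumes a 64-channel 27x27 spatial output
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def extract_features(self, spectral_input, spatial_input):
        spectral_feat = self.spectral_branch(spectral_input)
        spatial_feat = self.spatial_branch(spatial_input)
        spatial_feat = spatial_feat.view(spatial_feat.size(0), -1)  # flatten to (batch, 64*27*27)
        return torch.cat([spectral_feat, spatial_feat], dim=1)

    def forward(self, spectral_input, spatial_input):
        return self.classifier(self.extract_features(spectral_input, spatial_input))
```

## 3. Training Techniques

### 3.1 Loss Function

To handle the class imbalance common in hyperspectral data, a weighted cross-entropy loss is recommended. `nn.CrossEntropyLoss` expects one weight per class index, so the labels passed in must already exclude the background class 0 and be remapped to 0..num_classes-1:

```python
def calculate_class_weights(labels, num_classes):
    counts = np.bincount(labels.ravel(), minlength=num_classes).astype(np.float64)
    # Inverse class frequency; np.maximum avoids division by zero for absent classes
    weights = counts.sum() / np.maximum(counts, 1)
    return torch.tensor(weights, dtype=torch.float32)

# train_gt, num_classes and device are defined elsewhere
class_weights = calculate_class_weights(train_gt, num_classes)
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
```

### 3.2 Optimizer

Adam combined with a learning-rate scheduler is a good choice:

```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# mode='max': the monitored metric (e.g. validation accuracy or Kappa) should increase;
# call scheduler.step(val_metric) once per epoch
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5
)
```

### 3.3 Training Loop

A complete training epoch:

```python
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    for spectral, spatial, labels in loader:
        spectral, spatial = spectral.to(device), spatial.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(spectral, spatial)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)  # mean loss per batch
```

## 4. Evaluation and Tuning

### 4.1 Evaluation Metrics

Beyond accuracy, also track per-class F1-scores and the Kappa coefficient. The code below reports the support-weighted average F1; pass `average=None` to get one score per class:

```python
from sklearn.metrics import confusion_matrix, f1_score, cohen_kappa_score

def evaluate(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for spectral, spatial, labels in loader:
            spectral, spatial = spectral.to(device), spatial.to(device)
            outputs = model(spectral, spatial)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())

    cm = confusion_matrix(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted')
    kappa = cohen_kappa_score(all_labels, all_preds)
    return cm, f1, kappa
```

### 4.2 Hyperparameter Tuning

Key hyperparameters and their typical ranges:

| Parameter | Suggested range | Affects |
|---|---|---|
| Learning rate | 1e-5 to 1e-3 | Training stability and convergence speed |
| Dropout rate | 0.3 to 0.7 | Regularization strength |
| GRU hidden size | 64 to 256 | Spectral feature extraction capacity |
| CNN channels | 32 to 128 | Spatial feature extraction capacity |
| Attention layer dimension | 32 to 128 | Complexity of the attention mechanism |

### 4.3 Visualization

Visualize the learned feature space with t-SNE:

```python
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def visualize_features(model, loader, device):
    model.eval()
    features, labels = [], []
    with torch.no_grad():
        for spectral, spatial, lbl in loader:
            spectral, spatial = spectral.to(device), spatial.to(device)
            feat = model.extract_features(spectral, spatial)  # fused features before the classifier
            features.append(feat.cpu())
            labels.append(lbl)

    features = torch.cat(features).numpy()
    labels = torch.cat(labels).numpy()

    tsne = TSNE(n_components=2, perplexity=30)
    reduced = tsne.fit_transform(features)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(reduced[:, 0], reduced[:, 1], c=labels, alpha=0.6)
    plt.legend(*scatter.legend_elements(), title="Classes")
    plt.show()
```

## 5. Deployment Optimization and Production Advice

### 5.1 Model Compression

Knowledge distillation can shrink the model:

```python
class DistillLoss(nn.Module):
    def __init__(self, temp=3.0):
        super().__init__()
        self.temp = temp
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_out, teacher_out):
        # KLDivLoss expects log-probabilities as input and probabilities as target
        soft_student = torch.log_softmax(student_out / self.temp, dim=1)
        soft_teacher = torch.softmax(teacher_out / self.temp, dim=1)
        # Scaling by T^2 (as in Hinton et al.) keeps gradient magnitudes comparable across temperatures
        return self.kl_div(soft_student, soft_teacher) * self.temp ** 2
```

### 5.2 Exporting to ONNX

Export the trained model to a portable format:

```python
model.eval()  # export in inference mode (disables Dropout)
dummy_spectral = torch.randn(1, 200, 1).to(device)    # (batch, bands, 1); assumes 200 spectral bands
dummy_spatial = torch.randn(1, 3, 27, 27).to(device)  # assumes 3 PCA channels and a 27x27 window

torch.onnx.export(
    model,
    (dummy_spectral, dummy_spatial),
    "ssan_model.onnx",
    input_names=["spectral", "spatial"],
    output_names=["output"],
    dynamic_axes={
        "spectral": {0: "batch_size"},
        "spatial": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)
```

### 5.3 Practical Recommendations

- **Data preprocessing pipeline**: build a standardized preprocessing pipeline so that data is processed identically in training (offline) and in serving (online).
- **Model monitoring**: keep monitoring model performance after deployment and set up a data-drift detection mechanism.
- **Incremental learning**: periodically fine-tune the model on new data so it adapts to changing conditions.
- **Hardware acceleration**: use tools such as TensorRT to speed up inference, especially on edge devices.

Experiments on the Indian Pines dataset show that the complete spectral-spatial attention network improves classification accuracy by about 8-12% over a conventional CNN, with especially strong results on classes with few samples. In practice, a 27×27 spatial window, a learning rate of 0.0005, and mixed-precision training are recommended for better performance.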
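Section 1.2 loads the full ground-truth map but does not show how training samples are drawn from it. Because hyperspectral benchmarks are usually trained with only a few labeled pixels per class, here is a minimal sketch of a per-class split. It assumes label 0 marks unlabeled background (as in Indian Pines and Pavia University), that class labels run contiguously from 1 to K, and, as our own choice, that at least half of each class is kept for testing. The helper `split_per_class` and its `n_per_class` parameter are hypothetical, not part of the original code.

```python
import numpy as np

def split_per_class(gt, n_per_class, seed=0):
    """Hypothetical helper: draw up to n_per_class training pixels from each class.

    gt: (H, W) integer label map where 0 marks unlabeled background and
    classes are numbered 1..K. Returns (train_coords, train_labels,
    test_coords, test_labels); coords are (row, col) pairs and labels are
    remapped to 0..K-1.
    """
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(gt):
        if cls == 0:
            continue  # skip unlabeled background
        coords = np.argwhere(gt == cls)  # all (row, col) pixels of this class
        rng.shuffle(coords)              # shuffles rows, i.e. whole (row, col) pairs
        n_train = min(n_per_class, len(coords) // 2)  # keep at least half for testing
        train_idx.append(coords[:n_train])
        test_idx.append(coords[n_train:])
    train_coords = np.concatenate(train_idx)
    test_coords = np.concatenate(test_idx)
    # Shift labels 1..K to 0..K-1
    train_labels = gt[train_coords[:, 0], train_coords[:, 1]] - 1
    test_labels = gt[test_coords[:, 0], test_coords[:, 1]] - 1
    return train_coords, train_labels, test_coords, test_labels
```

The returned labels are already in the 0..K-1 form that `nn.CrossEntropyLoss` expects, and the coordinates can be used to cut per-pixel inputs out of the data cube.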
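The model in Section 2.3 assumes that the spatial branch receives a 3-channel 27×27 patch produced by PCA, and that the spectral branch receives one pixel's full spectrum as a sequence of bands, but the article does not show how these inputs are cut from the cube. The NumPy sketch below is one way to do it under those assumptions: PCA is computed via an SVD of the centered pixel matrix, and image borders are handled with reflect padding (our choice; zero padding is also common). The names `pca_reduce` and `make_inputs` are hypothetical.

```python
import numpy as np

def pca_reduce(img, n_components=3):
    """PCA over the spectral axis: (H, W, C) -> (H, W, n_components)."""
    h, w, c = img.shape
    x = img.reshape(-1, c).astype(np.float64)  # one row per pixel (astype copies)
    x -= x.mean(axis=0)                        # center each band
    # Rows of vt are the principal directions, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    return (x @ vt[:n_components].T).reshape(h, w, n_components)

def make_inputs(img, reduced, row, col, window=27):
    """Hypothetical helper: build the two model inputs for pixel (row, col).

    Returns spectral: (bands, 1), one band per GRU time step,
            spatial:  (n_components, window, window), channels-first for Conv2d.
    """
    half = window // 2
    padded = np.pad(reduced, ((half, half), (half, half), (0, 0)), mode="reflect")
    # After padding, pixel (row, col) sits at (row + half, col + half),
    # so this slice is the window centered on it
    patch = padded[row:row + window, col:col + window, :]
    spectral = img[row, col, :].astype(np.float32)[:, None]
    spatial = np.transpose(patch, (2, 0, 1)).astype(np.float32)
    return spectral, spatial
```

Stacking these per-pixel arrays along a new first axis gives the (batch, bands, 1) and (batch, 3, 27, 27) tensors the network consumes.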
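The augmentation pipeline in Section 1.3 only notes that the spectral dimension can also be randomly perturbed. One common way to do this, sketched here as an assumption rather than as the article's method, is to multiply a pixel's spectrum by a single random factor (roughly mimicking brightness variation) and add small Gaussian noise to every band. The function name `spectral_jitter` and its default ranges are hypothetical and would need tuning per dataset.

```python
import numpy as np

def spectral_jitter(spectrum, scale_range=(0.9, 1.1), noise_std=0.01, rng=None):
    """Randomly perturb a spectrum along the band axis.

    spectrum: array whose first axis is the band axis, e.g. (bands, 1).
    The whole spectrum is multiplied by one random factor drawn from
    scale_range, then i.i.d. Gaussian noise with std noise_std is added
    to every element. The input dtype is preserved.
    """
    rng = np.random.default_rng() if rng is None else rng
    scale = rng.uniform(*scale_range)
    noise = rng.normal(0.0, noise_std, size=spectrum.shape)
    return (spectrum * scale + noise).astype(spectrum.dtype)
```

Applied only to training samples, this complements the geometric transforms on the spatial patch without changing the input shapes.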
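For reference, the attention pooling that `SpectralAttention` in Section 2.1 applies to the BiGRU outputs can be written as follows. Here $h_t \in \mathbb{R}^{2d}$ is the concatenated forward and backward hidden state at band $t$, $d$ is `hidden_dim`, $W$ and $b$ are the weight and bias of the first `nn.Linear`, $w$ is the bias-free second `nn.Linear`, and $T$ is the number of bands. The symbols are our notation, not the article's:

$$
e_t = w^\top \tanh(W h_t + b), \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \qquad
c = \sum_{t=1}^{T} \alpha_t h_t
$$

The softmax runs over the band axis, so bands with a higher energy $e_t$ contribute more to the branch output $c \in \mathbb{R}^{2d}$, which is the 256-dimensional spectral feature when $d = 128$.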
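The hard-coded input size of the first classifier layer in Section 2.3 follows from the two branch outputs: a BiGRU with hidden size 128 yields 2 × 128 = 256 spectral features, and the spatial branch keeps the 27×27 window (3×3 convolutions with padding=1) at 64 channels. The short calculation below, under those same assumptions, also counts the parameters of that single `nn.Linear` layer, which account for the bulk of the model and are one reason the compression in Section 5.1 pays off.

```python
hidden_dim = 128          # GRU hidden size used in FusionNetwork
spatial_channels = 64     # output channels of SpatialAttention.conv1
window = 27               # spatial patch size (padding=1 keeps it at 27x27)
fc_hidden = 512           # width of the first classifier layer

spectral_features = 2 * hidden_dim                     # bidirectional GRU
spatial_features = spatial_channels * window * window  # flattened feature map
in_features = spectral_features + spatial_features
fc1_params = in_features * fc_hidden + fc_hidden       # weights + biases

print(spectral_features, spatial_features, in_features, fc1_params)
# 256 46656 46912 24019456
```

Shrinking the window or pooling the spatial feature map before flattening reduces this count quadratically in the window size.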
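Hyperspectral classification results are commonly reported as overall accuracy (OA), average accuracy (AA, the mean of per-class recalls) and Cohen's Kappa. All three can be derived from the confusion matrix returned by `evaluate` in Section 4.1; the NumPy sketch below assumes scikit-learn's convention that rows are true classes and columns are predictions. The function name `oa_aa_kappa` is hypothetical.

```python
import numpy as np

def oa_aa_kappa(cm):
    """Compute OA, AA and Cohen's Kappa from a confusion matrix.

    cm[i, j] = number of samples of true class i predicted as class j.
    """
    cm = np.asarray(cm, dtype=np.float64)
    n = cm.sum()
    oa = np.trace(cm) / n  # observed agreement p_o
    # Per-class recall; np.maximum guards against classes with no true samples
    per_class_recall = np.diag(cm) / np.maximum(cm.sum(axis=1), 1)
    aa = per_class_recall.mean()
    # Chance agreement p_e from the row (true) and column (predicted) marginals
    pe = (cm.sum(axis=1) * cm.sum(axis=0)).sum() / n ** 2
    kappa = (oa - pe) / (1 - pe)
    return oa, aa, kappa
```

For the two-class matrix `[[45, 5], [10, 40]]`, OA and AA are both 0.85, the chance agreement is $p_e = (50 \cdot 55 + 50 \cdot 45) / 100^2 = 0.5$, and Kappa is $(0.85 - 0.5) / (1 - 0.5) = 0.7$.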
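`DistillLoss` in Section 5.1 covers only the soft-target term. In the usual knowledge-distillation setup of Hinton et al., the student is trained on a weighted sum of that term and the ordinary cross-entropy against the true labels. Written out with student logits $z^s$, teacher logits $z^t$, temperature $T$, true label $y$ and a mixing weight $\alpha \in [0, 1]$ (a hyperparameter introduced here for illustration; the article does not specify one):

$$
\mathcal{L} = \alpha \, T^2 \, \mathrm{KL}\!\left(\operatorname{softmax}(z^t / T) \,\middle\|\, \operatorname{softmax}(z^s / T)\right) + (1 - \alpha) \, \mathrm{CE}\!\left(y, \operatorname{softmax}(z^s)\right)
$$

The first term is the KL divergence that `DistillLoss` computes, multiplied by $T^2$ so that its gradients stay on a scale comparable to the cross-entropy term as $T$ changes; the teacher logits should be computed under `torch.no_grad()`.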
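The closing paragraph recommends mixed-precision training but gives no code. A minimal sketch of `train_epoch` from Section 3.3 adapted to PyTorch automatic mixed precision (`torch.autocast` plus a gradient scaler) might look like this; it falls back to ordinary FP32 training when the device is not a GPU. The name `train_epoch_amp` is hypothetical, and the scaler is assumed to be created once and reused across epochs.

```python
import torch
import torch.nn as nn

def train_epoch_amp(model, loader, criterion, optimizer, scaler, device):
    """One training epoch with automatic mixed precision (enabled only on CUDA)."""
    use_amp = device.type == "cuda"
    model.train()
    total_loss = 0.0
    for spectral, spatial, labels in loader:
        spectral, spatial = spectral.to(device), spatial.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Run eligible ops in float16; autocast is a no-op when disabled
        with torch.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(spectral, spatial)
            loss = criterion(outputs, labels)
        # Scale the loss to avoid float16 gradient underflow; step() unscales
        # the gradients and skips the update if they contain inf/NaN
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
    return total_loss / len(loader)

# Create the scaler once, outside the epoch loop (recent PyTorch versions
# prefer torch.amp.GradScaler("cuda", enabled=...) for the same purpose):
# scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
```

With `enabled=False` the scaler passes the loss through unchanged and simply calls `optimizer.step()`, so the same function works on CPU.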