当前位置: 首页> 财经> 创投人物 > 免费网站建设设计制作公司_石家庄精准推广_百度指数可以查询多长时间的_腾讯广告投放平台官网

免费网站建设设计制作公司_石家庄精准推广_百度指数可以查询多长时间的_腾讯广告投放平台官网

时间:2025/8/7 4:42:05来源:https://blog.csdn.net/u014158430/article/details/146243879 浏览次数:0次
免费网站建设设计制作公司_石家庄精准推广_百度指数可以查询多长时间的_腾讯广告投放平台官网

1. 依赖库安装

如果你还没安装相关库,请先执行:

pip install torch torchaudio torchvision scikit-learn matplotlib tqdm

2. 数据加载

这里假设你有一个 音频分类数据集,其文件结构如下:

dataset/
│── train/
│   ├── class_0/
│   │   ├── audio_0.wav
│   │   ├── audio_1.wav
│   ├── class_1/
│   │   ├── audio_0.wav
│   │   ├── audio_1.wav
│── val/
│   ├── class_0/
│   ├── class_1/

实现数据加载器:

import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms# 音频数据集类
class AudioDataset(Dataset):def __init__(self, root_dir, sample_rate=16000, n_mfcc=40):self.root_dir = root_dirself.sample_rate = sample_rateself.n_mfcc = n_mfccself.classes = sorted(os.listdir(root_dir))  # 目录名作为类别self.file_paths = []self.labels = []for label, class_name in enumerate(self.classes):class_dir = os.path.join(root_dir, class_name)for file_name in os.listdir(class_dir):self.file_paths.append(os.path.join(class_dir, file_name))self.labels.append(label)self.mfcc_transform = torchaudio.transforms.MFCC(sample_rate=self.sample_rate,n_mfcc=self.n_mfcc,melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 64})def __len__(self):return len(self.file_paths)def __getitem__(self, idx):file_path = self.file_paths[idx]label = self.labels[idx]waveform, sr = torchaudio.load(file_path)if sr != self.sample_rate:resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)waveform = resampler(waveform)mfcc = self.mfcc_transform(waveform).squeeze(0)  # (n_mfcc, time)mfcc = mfcc.unsqueeze(0).repeat(3, 1, 1)  # (3, n_mfcc, time) 适配 ResNetreturn mfcc, label# 创建数据加载器
def get_dataloaders(train_dir, val_dir, batch_size=32, num_workers=2):train_dataset = AudioDataset(train_dir)val_dataset = AudioDataset(val_dir)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_loader, val_loader

3. 训练和验证

import torch.optim as optim
from tqdm import tqdmdef train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, device="cuda"):model = model.to(device)criterion = torch.nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=lr)for epoch in range(num_epochs):print(f"Epoch [{epoch+1}/{num_epochs}]")# 训练阶段model.train()total_loss, correct, total = 0, 0, 0for inputs, labels in tqdm(train_loader):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()total += labels.size(0)correct += (outputs.argmax(dim=1) == labels).sum().item()train_acc = correct / totalprint(f"Train Loss: {total_loss/len(train_loader):.4f}, Train Acc: {train_acc:.4f}")# 验证阶段model.eval()total_loss, correct, total = 0, 0, 0with torch.no_grad():for inputs, labels in tqdm(val_loader):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)total_loss += loss.item()total += labels.size(0)correct += (outputs.argmax(dim=1) == labels).sum().item()val_acc = correct / totalprint(f"Val Loss: {total_loss/len(val_loader):.4f}, Val Acc: {val_acc:.4f}")return model

4. 混淆矩阵可视化

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplaydef evaluate_model(model, val_loader, device="cuda"):model.eval()all_preds = []all_labels = []with torch.no_grad():for inputs, labels in tqdm(val_loader):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)preds = outputs.argmax(dim=1).cpu().numpy()labels = labels.cpu().numpy()all_preds.extend(preds)all_labels.extend(labels)return np.array(all_labels), np.array(all_preds)def plot_confusion_matrix(y_true, y_pred, class_names):cm = confusion_matrix(y_true, y_pred)disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)disp.plot(cmap=plt.cm.Blues, values_format="d")plt.xticks(rotation=45)plt.show()

5. 运行完整训练流程

if __name__ == "__main__":train_dir = "dataset/train"val_dir = "dataset/val"batch_size = 32num_epochs = 10device = "cuda" if torch.cuda.is_available() else "cpu"# 加载数据train_loader, val_loader = get_dataloaders(train_dir, val_dir, batch_size)# 初始化模型model = ResNetSE(num_classes=len(os.listdir(train_dir)))# 训练模型trained_model = train_model(model, train_loader, val_loader, num_epochs=num_epochs, device=device)# 计算混淆矩阵y_true, y_pred = evaluate_model(trained_model, val_loader, device=device)# 绘制混淆矩阵class_names = sorted(os.listdir(train_dir))plot_confusion_matrix(y_true, y_pred, class_names)

6. 总结

数据加载

  • 通过 torchaudio 提取 MFCC 特征,并适配 ResNet 输入格式。

ResNet-SE 训练

  • 训练过程包含 Adam 优化器 + 交叉熵损失,支持 GPU 训练。

混淆矩阵可视化

  • 通过 sklearn 计算混淆矩阵,并绘制 分类效果图

改进方向

🚀 模型优化

  • 使用 ResNet-34/50 替代 ResNet-18 提升表达能力。
  • 结合 SpecAugment 增强数据,提高鲁棒性。

推理加速

  • 采用 TorchScript / ONNX 进行模型导出,提高部署效率。

💡 数据增强

  • 额外使用 时域和频域增强(如 torchaudio.transforms.TimeMasking)。

这样,你就能完整训练 ResNet-SE + MFCC 进行音频分类,并分析模型性能了!💪🚀

关键字:免费网站建设设计制作公司_石家庄精准推广_百度指数可以查询多长时间的_腾讯广告投放平台官网

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: