当前位置: 首页> 游戏> 手游 > 河南搜索引擎优化_宁波网络营销策划公司_生成关键词的软件免费_营销活动推广策划

河南搜索引擎优化_宁波网络营销策划公司_生成关键词的软件免费_营销活动推广策划

时间:2025/7/12 4:47:53来源:https://blog.csdn.net/m0_66434421/article/details/147355698 浏览次数:0次
河南搜索引擎优化_宁波网络营销策划公司_生成关键词的软件免费_营销活动推广策划

手搓LeNet-5(基础模型)实现交通标志识别

  • 一、环境准备
    • 1. 安装Python环境
    • 2. 安装CUDA(可选,仅需GPU加速时)
    • 3. 配置虚拟环境
    • 4. 安装PyTorch核心库
    • 5. 安装辅助库
    • 6. 验证安装
    • 7. 准备数据集
    • 8.常见问题处理
  • 二、 数据集处理
    • 三、 模型实现
    • 四、训练流程
  • 五、模型部署
      • 5.1 导出为ONNX格式
      • 5.2 使用Flask部署服务
      • 5.3 测试API
  • 六、总结

本文将使用PyTorch从零实现经典的LeNet-5模型,并在交通标志识别数据集上进行训练和部署。完整代码可直接运行。


一、环境准备

1. 安装Python环境

  1. 访问Python官网下载安装包:
    python 官网
  2. 选择 Python 3.8+ 版本(推荐3.8.10)
  3. 安装时勾选 “Add Python to PATH”

2. 安装CUDA(可选,仅需GPU加速时)

  1. 建议搭配:CUDA 11.8 + cuDNN 8.6.0
    CUDA+cuDNN 详细安装配置教程

3. 配置虚拟环境

  1. 打开命令提示符(CMD)或PowerShell
  2. 创建并激活虚拟环境(激活后命令行前缀会显示 (lenet_env)):
    # 创建虚拟环境
    python -m venv lenet_env# 激活环境
    .\lenet_env\Scripts\activate
    

4. 安装PyTorch核心库

  1. 根据是否使用GPU选择命令:
    # GPU版本(需CUDA 11.8)推荐
    pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 -f https://mirrors.aliyun.com/pytorch-wheels/cu118/# 或CPU版本
    pip install torch torchvision torchaudio
    

5. 安装辅助库

  1. 安装其他库
    pip install matplotlib numpy flask requests onnx onnxruntime
    

6. 验证安装

  1. 创建 check_env.py 文件并运行:
    import torchprint("PyTorch版本:", torch.__version__)
    print("CUDA可用:", torch.cuda.is_available())
    print("设备数量:", torch.cuda.device_count())
    
    预期输出示例:
    PyTorch版本: 2.3.1+cu118
    CUDA可用: True
    设备数量: 1
    

7. 准备数据集

  1. 下载GTSRB数据集:
  • 训练集:https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Training_Images.zip
  • 测试集:https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip
  1. 手动解压文件到以下目录结构:
     C:/└─your_project/├─data/│   ├─train/│   │   └─GTSRB/Final_Training/Images/...│   └─test/│       └─GTSRB/Final_Test/Images/...└─code/
    

8.常见问题处理

  1. CUDA不可用
  • 检查显卡驱动是否为最新版本
  • 确保安装的PyTorch版本与CUDA版本匹配
  • 运行 nvidia-smi 验证显卡识别
  1. 数据集路径错误
  • 使用绝对路径(如 C:/your_project/data/train
  • 确保解压后的文件夹层级正确
  1. 内存不足
  • 降低batch_size参数(建议从64改为32)
  • 关闭其他占用显存的程序

二、 数据集处理

使用德国交通标志识别基准(GTSRB)数据集:

import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader# 数据预处理
transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载数据集
train_set = datasets.ImageFolder(root='./data/train', transform=transform)
test_set = datasets.ImageFolder(root='./data/test', transform=transform)# 创建数据加载器
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)print(f"训练集大小: {len(train_set)}")
print(f"测试集大小: {len(test_set)}")
print(f"类别数量: {len(train_set.classes)}")

三、 模型实现

LeNet-5的PyTorch实现:

import torch.nn as nnclass LeNet5(nn.Module):def __init__(self, num_classes=43):super(LeNet5, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, kernel_size=5),  # 输入通道改为3(RGB)nn.ReLU(),nn.MaxPool2d(kernel_size=2),nn.Conv2d(6, 16, kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.classifier = nn.Sequential(nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, num_classes))def forward(self, x):x = self.features(x)x = torch.flatten(x, 1)x = self.classifier(x)return xmodel = LeNet5()
print(model)

四、训练流程

训练配置与执行:

import torch.optim as optimdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 训练循环
for epoch in range(20):model.train()running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 验证model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f"Epoch [{epoch+1}/20] Loss: {running_loss/len(train_loader):.4f} | Acc: {100*correct/total:.2f}%")# 保存模型
torch.save(model.state_dict(), "lenet5_traffic_sign.pth")

五、模型部署

5.1 导出为ONNX格式

dummy_input = torch.randn(1, 3, 32, 32).to(device)
torch.onnx.export(model, dummy_input, "lenet5.onnx", input_names=["input"], output_names=["output"])

5.2 使用Flask部署服务

from flask import Flask, request, jsonify
from PIL import Image
import numpy as npapp = Flask(__name__)
model.load_state_dict(torch.load("lenet5_traffic_sign.pth"))
model.eval()def preprocess_image(image):transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])return transform(image).unsqueeze(0)@app.route('/predict', methods=['POST'])
def predict():if 'file' not in request.files:return jsonify({'error': 'No file uploaded'})file = request.files['file']image = Image.open(file.stream).convert('RGB')tensor = preprocess_image(image).to(device)with torch.no_grad():outputs = model(tensor)_, predicted = torch.max(outputs, 1)return jsonify({'class_id': predicted.item(), 'class_name': train_set.classes[predicted.item()]})if __name__ == '__main__':app.run(host='0.0.0.0', port=5000)

5.3 测试API

使用curl测试:

curl -X POST -F "file=@test_sign.jpg" http://localhost:5000/predict

六、总结

通过本文我们实现了:

  1. LeNet-5的PyTorch实现
  2. 交通标志数据集的加载与处理
  3. 模型的训练与验证
  4. 生产环境部署方案

完整代码需配合GTSRB数据集使用,数据集可从这里下载。建议使用GPU加速训练过程。

关键字:河南搜索引擎优化_宁波网络营销策划公司_生成关键词的软件免费_营销活动推广策划

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: