当前位置: 首页> 汽车> 维修 > 微信管理办法_镇江网站制作咨询_aso优化教程_百度推广官网登录

微信管理办法_镇江网站制作咨询_aso优化教程_百度推广官网登录

时间:2025/8/28 5:03:34来源:https://blog.csdn.net/Sxiaocai/article/details/143863927 浏览次数: 0次
微信管理办法_镇江网站制作咨询_aso优化教程_百度推广官网登录

介绍

        本文将介绍如何使用 PyTorch 实现一个简化版的 GoogLeNet 网络来进行 MNIST 图像分类。GoogLeNet 是 Google 提出的深度卷积神经网络(CNN),其通过 Inception 模块大大提高了计算效率并提升了分类性能。我们将实现一个简化版的 GoogLeNet,用于处理 MNIST 数据集,该数据集由手写数字图片组成,适合用于小规模的图像分类任务。

项目结构

        我们将代码分为两个部分:

  • 训练脚本 train.py:包括数据加载、模型构建、训练过程等。
  • 测试脚本 test.py:用于加载训练好的模型并在测试集上评估性能。

项目依赖

        在开始之前,我们需要安装以下 Python 库:

  • torch:PyTorch 深度学习框架
  • torchvision:提供数据加载和图像变换功能
  • matplotlib:用于可视化

        可以通过以下命令安装所有依赖:

pip install -r requirements.txt

  requirements.txt 文件内容如下:

torch==2.0.1
torchvision==0.15.0
matplotlib==3.6.3

数据预处理与加载

1. 数据加载和预处理

        在训练模型之前,我们需要对 MNIST 数据集进行预处理。以下是数据加载和预处理的代码:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transformsdef get_data_loader(batch_size=64, train=True):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 正规化到 [-1, 1] 范围])dataset = datasets.MNIST(root='./data', train=train, download=True, transform=transform)return DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)

        这里,我们使用了 transforms.Compose 来进行数据预处理,包括将图像转换为 Tensor 格式,并进行归一化处理。


训练部分:train.py

2. 模型定义:简化版 GoogLeNet

        为了在 MNIST 数据集上训练,我们构建了一个简化版的 GoogLeNet,包含三个 Inception 模块和一个全连接层。每个 Inception 模块由一个卷积层和一个最大池化层组成。简化的 GoogLeNet 模型如下:

import torch.nn as nnclass SimpleGoogLeNet(nn.Module):def __init__(self, num_classes=10):super(SimpleGoogLeNet, self).__init__()# 第一个 Inception 模块self.inception1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 第二个 Inception 模块self.inception2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 第三个 Inception 模块self.inception3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=2))# 分类器:全连接层 + Dropout 层self.fc = nn.Sequential(nn.Linear(128 * 3 * 3, 1024),nn.ReLU(),nn.Dropout(0.5),nn.Linear(1024, num_classes))def forward(self, x):x = self.inception1(x)x = self.inception2(x)x = self.inception3(x)x = x.view(x.size(0), -1)  # 展平输入x = self.fc(x)return x

3. 训练函数

        训练过程包括前向传播、反向传播和优化。我们将使用 Adam 优化器和 交叉熵损失 来训练模型:

import torch.optim as optim
from tqdm import tqdmdef train_epoch(model, device, train_loader, criterion, optimizer):model.train()running_loss = 0.0correct = 0total = 0with tqdm(train_loader, desc="Training", unit="batch", ncols=100) as pbar:for inputs, labels in pbar:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()pbar.set_postfix(loss=running_loss / (total // 64), accuracy=100 * correct / total)return running_loss / len(train_loader), 100 * correct / total

4. 训练脚本:train.py

        训练脚本将包括模型的定义、数据加载、训练过程等:

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from model import SimpleGoogLeNet  # 假设模型在 model.py 文件中def train_model():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = SimpleGoogLeNet().to(device)train_loader = get_data_loader(batch_size=64, train=True)criterion = torch.nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)epochs = 10for epoch in range(epochs):loss, accuracy = train_epoch(model, device, train_loader, criterion, optimizer)print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.2f}%")torch.save(model.state_dict(), "simplified_googlenet.pth")  # 保存模型if __name__ == '__main__':train_model()


测试部分:test.py

5. 测试函数

        在测试阶段,我们将使用 torch.no_grad() 禁用梯度计算,提高推理速度,并计算模型在测试集上的准确率:

def test_model(model, device, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f'Test Accuracy: {accuracy:.2f}%')

6. 测试脚本:test.py

        测试脚本将加载训练好的模型并对测试集进行评估:

import torch
from model import SimpleGoogLeNet  # 假设模型在 model.py 文件中
from torch.utils.data import DataLoader
from torchvision import datasets, transformsdef get_test_loader(batch_size=64):transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 正规化到 [-1, 1] 范围])test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = SimpleGoogLeNet().to(device)model.load_state_dict(torch.load("simplified_googlenet.pth"))  # 加载训练好的模型test_loader = get_test_loader(batch_size=64)test_model(model, device, test_loader)if __name__ == '__main__':main()

总结

        本文介绍了如何使用 PyTorch 实现简化版的 GoogLeNet,并将代码分为训练(train.py)和测试(test.py)部分。在训练脚本中,我们定义了一个简化版的 GoogLeNet,训练模型并保存训练结果。而在测试脚本中,我们加载训练好的模型并在测试集上进行评估。

        通过这些步骤,我们能够快速地实现一个高效的图像分类模型,并在 MNIST 数据集上进行训练与测试。

完整项目
GitHub - qxd-ljy/GoogLeNet-PyTorch: 使用PyTorch实现GooLeNet进行MINST图像分类使用PyTorch实现GooLeNet进行MINST图像分类. Contribute to qxd-ljy/GoogLeNet-PyTorch development by creating an account on GitHub.icon-default.png?t=O83Ahttps://github.com/qxd-ljy/GoogLeNet-PyTorchGitHub - qxd-ljy/GoogLeNet-PyTorch: 使用PyTorch实现GooLeNet进行MINST图像分类使用PyTorch实现GooLeNet进行MINST图像分类. Contribute to qxd-ljy/GoogLeNet-PyTorch development by creating an account on GitHub.icon-default.png?t=O83Ahttps://github.com/qxd-ljy/GoogLeNet-PyTorch

        希望这篇博客对你有所帮助,欢迎继续探索 PyTorch 和深度学习的更多应用!

关键字:微信管理办法_镇江网站制作咨询_aso优化教程_百度推广官网登录

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: