当前位置: 首页> 文旅> 旅游 > 国家企业信用信息公示系统下载_无锡网络公司哪家服务好_高质量关键词搜索排名_seo网站优化做什么

国家企业信用信息公示系统下载_无锡网络公司哪家服务好_高质量关键词搜索排名_seo网站优化做什么

时间:2025/9/2 16:59:54来源:https://blog.csdn.net/weixin_51904054/article/details/146327662 浏览次数:0次
国家企业信用信息公示系统下载_无锡网络公司哪家服务好_高质量关键词搜索排名_seo网站优化做什么
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transformsimport matplotlib.pyplot as pltfrom model import LeNetdef test_data_process():# 数据处理transform = transforms.Compose([transforms.Resize(28), transforms.ToTensor()])# 测试数据集test_dataset = datasets.FashionMNIST(root="./data",train=False,            # 是否作为训练集download=True,          # 是否需要下载transform=transform)    # 怎么做变化# 数据集的加载test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0)return test_loaderdef model_test(model, test_loader):# 指定设备device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 模型放置在设备中,在返回给模型model = model.to(device)# 初始化参数test_corrects = 0.0         # 测试一个样本,总计多少个样本是正确的test_num = 0                # 累计测试样例的数量with torch.no_grad():# 模型设置为验证模式model.eval()for batch_size, (images, labels) in enumerate(test_loader):if  batch_size % 1000:print("Batch size is {}".format(batch_size))print("images shape: {}".format(images.shape))print("labels shape: {}".format(labels.item()))# 将数据和标签放置在设备中images, labels = images.to(device), labels.to(device)# 将数据放置在模型中进行前向传播output = model(images)# 从输出结果中获取最大值pred_label = output.argmax(dim=1)# 累计测试集中的正样本数test_corrects += torch.sum(pred_label == labels)# 更新模型训练了多少数据test_num += images.size(0)# 计算准确率test_accuracy = test_corrects.double().item() / test_numreturn test_accuracydef show_image(tensor):plt.imshow(tensor.squeeze().numpy(), cmap='gray')plt.show()

假定数据集有6万张图片,并且都是彩色图片

test_dataset = datasets.FashionMNIST(root="./data",
                                        train=False,            # 是否作为训练集
                                        download=True,          # 是否需要下载
                                        transform=transform)    # 怎么做变化

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=True, num_workers=0)

你使用 DataLoader 来加载数据集,并设置了 batch_size=64shuffle=True。这意味着:

  • Batch Size: 每个批次包含 64 张图片。
  • Shuffle: 在每个 epoch 开始时,数据会被打乱。
  • Num Workers: 设置为 0,意味着数据加载将在主进程中进行(而不是使用额外的进程)。

你有 60,000 张图片,那么总共有:

6万张图片,每64张为一捆数据。共计938捆,最后一捆不足64张。

with torch.no_grad():# 模型设置为验证模式model.eval()for batch_index, (images, labels) in enumerate(test_loader):
  • with torch.no_grad():表示如下的代码不要梯度

for batch_index, (images, labels) in enumerate(test_loader):
  • batch_index是从第0捆,一直会遍历到第938捆数据。
  • 第0捆中,
    • images的和标签的尺寸分别是:

images shape: torch.Size([64, 3, 28, 28])

labels shape: torch.Size([64])

[64, 3, 28, 28]表示有64张图片,每张图片有RGB,3个通道;图像的宽和高分别是28*28;


获取第3张图像

image_tensor = images[2]
print("图片的格式:{}".format(image_tensor.shape))

图片的格式:torch.Size([3, 28, 28])


获取第3张图像的标签

third_image_label = labels[2].item()

labels[2]是一个张量,item()之后才能获取到张量中的值。


显示一张图片

image_tensor = images[3]    # [1, 3, 28, 28]plt.imshow(image_tensor.squeeze().numpy(), cmap='gray')
plt.show()

如果图片是3通道的

import matplotlib.pyplot as plt# 创建一个形状为 [1, 3, 28, 28] 的张量
tensor = torch.randn(1, 3, 28, 28)# 使用 squeeze() 移除所有长度为1的维度
squeezed_tensor = tensor.squeeze()# 将张量转换为 NumPy 数组
numpy_array = squeezed_tensor.numpy()# 显示图像
plt.imshow(numpy_array.transpose(1, 2, 0))  # 需要调整通道顺序以适应 Matplotlib
plt.show()

如果图片是单通道的

import matplotlib.pyplot as plt# 创建一个形状为 [1, 1, 28, 28] 的张量
tensor = torch.randn(1, 1, 28, 28)# 使用 squeeze() 移除所有长度为1的维度
squeezed_tensor = tensor.squeeze()# 将张量转换为 NumPy 数组
numpy_array = squeezed_tensor.numpy()# 显示图像
plt.imshow(numpy_array, cmap='gray')
plt.show()

关键字:国家企业信用信息公示系统下载_无锡网络公司哪家服务好_高质量关键词搜索排名_seo网站优化做什么

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: