当前位置: 首页> 汽车> 维修 > 专业网站优化seo_适合发表个人文章的平台_今天上海最新新闻事件_广告联盟大全

专业网站优化seo_适合发表个人文章的平台_今天上海最新新闻事件_广告联盟大全

时间:2025/7/11 14:42:43来源:https://blog.csdn.net/jacke121/article/details/142210404 浏览次数: 0次
专业网站优化seo_适合发表个人文章的平台_今天上海最新新闻事件_广告联盟大全

目录

官方示例运行

修改分类数,训练自己的数据


官方示例运行

demo1.py

import timefrom urllib.request import urlopen
import torch
from PIL import Image
import timm# 打开图像
img = Image.open(urlopen('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'))# 创建模型,设置 pretrained=False 以避免从网络加载预训练权重
model = timm.create_model('mobilenetv4_conv_large.e600_r384_in1k', pretrained=False)# 手动加载本地的预训练权重
pretrained_weights_path = 'models/pytorch_model.bin'
model.load_state_dict(torch.load(pretrained_weights_path))# 设置模型为评估模式
model = model.eval()
model = model.cuda()# 获取模型特定的变换(归一化、调整大小等)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)for i in range(10):img = transforms(img).cuda()start=time.time()output = model(img.unsqueeze(0))  # unsqueeze 将单张图片扩展为批量大小为1print(img.shape,output.shape,'time', time.time()-start)# 获取 top 5 结果
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)# 打印结果
print(top5_probabilities)
print(top5_class_indices)

gpu 384 平均7ms

修改分类数,训练自己的数据

import torch
import timm
from torch import nn# 创建模型,设置 pretrained=False 以避免从网络加载预训练权重
model = timm.create_model('mobilenetv4_conv_large.e600_r384_in1k', pretrained=False)# 修改分类数为10
num_classes = 10
model.classifier = nn.Linear(model.classifier.in_features, num_classes)# 手动加载本地的预训练权重
pretrained_weights_path = 'models/pytorch_model.bin'
state_dict = torch.load(pretrained_weights_path)
# 如果state_dict中包含分类层的权重,需要删除
state_dict.pop('classifier.weight', None)
state_dict.pop('classifier.bias', None)
model.load_state_dict(state_dict, strict=False)# 检查是否成功加载
print("Model loaded with custom weights.")# 下面是加载自己的数据并进行训练的示例代码from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据预处理
transform = transforms.Compose([transforms.Resize((384, 384)),  # 模型输入尺寸transforms.ToTensor(),
])# 加载训练数据和验证数据
train_dataset = datasets.ImageFolder('path_to_train_data', transform=transform)
val_dataset = datasets.ImageFolder('path_to_val_data', transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)# 训练模型
num_epochs = 10
model.train()
for epoch in range(num_epochs):running_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")# 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():for inputs, labels in val_loader:outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print(f"Accuracy: {100 * correct / total}%")

关键字:专业网站优化seo_适合发表个人文章的平台_今天上海最新新闻事件_广告联盟大全

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: