当前位置: 首页> 娱乐> 八卦 > 二次开发信怎么写_2022年企业所得税税率_百度的网页地址_现在阳性最新情况

二次开发信怎么写_2022年企业所得税税率_百度的网页地址_现在阳性最新情况

时间:2025/7/10 18:16:54来源:https://blog.csdn.net/2302_79795489/article/details/142985740 浏览次数:0次
二次开发信怎么写_2022年企业所得税税率_百度的网页地址_现在阳性最新情况

一、现有网络模型的修改和使用

以torchvision里的models模块里的vgg16网络模型为例:

 加载现有网络模型:

 vgg16_false=torchvision.models.vgg16(weights=None) #表示不加载预训练权重

vgg16_true=torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT) #表示加载预训练权重

修改现有网络模型:

添加新层:

vgg16_true.classifier.add_module(name="add_linear",module=nn.Linear(1000,10))

#在classifier层级下,添加新模块(层),参数是(模块名,模块)

修改现有的层:

vgg_false.classifier[6]=nn.Linear(4096,10)

#修改classifier的第六模块(层),直接赋予新定义

import torchvision
from torch import nn
from torchvision.models import VGG16_Weights#加载现有网络模型
vgg16_false=torchvision.models.vgg16(weights=None)
vgg16_true=torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
print("ok")
# pretrained=False改为weight=None  pretrained=True改为weights=VGG16_Weights.DEFAULT
print(vgg16_true)#vgg_16最终分类是1000,而CIFAR10最终分类是10,如何更改?1、把out_features=1000改成=10,2、新增一层线性层
dataset=torchvision.datasets.CIFAR10(root="../dataset",transform=torchvision.transforms.ToTensor(),download=True)#修改现有网络模型
vgg16_false.classifier[6]=nn.Linear(4096,10) #法一:直接修改原来那个线性层,重新赋值即可
print(vgg16_false)vgg16_true.classifier.add_module(name="add_linear",module=nn.Linear(1000,10)) #法二:新增线性层模块(module)
print((vgg16_true))

 

二、网络模型的保存与读取

.pth是pytorch中常用的模型权重保存格式

保存方式1:模型结构+参数

torch.save(模型实例,文件保存路径)

     例子中,vgg16是VGG16模型类的一个实例对象,

  torchvision.models.vgg16 是一个函数,它返回一个已初始化的 VGG16 模型对象

保存方式2:仅参数(适合大模型)

torch.save(模型实例的所有参数vgg16.state_dict(),文件保存路径)

  state_dict():返回模型的所有参数

加载方式1:

model=torch.load(文件路径)

加载方式2:

model=torchvision.models.vgg16(weights=None) #新建不带参数的模型结构

model.load_state_dict(torch.load(文件路径)) #加载模型参数

 注意:自定义的模型要import模型类定义的文件夹,实例要能和类联系得上才行

model_save.py: 

import torch
import torchvision
from torch import nnvgg16=torchvision.models.vgg16(weights=None)
#保存方式1:模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")#保存方式2:模型参数(官方推荐,对于大模型很友好)
torch.save(vgg16.state_dict(),"vgg16_method2.pth") #把状态(参数)保存成字典#注意:用自定义的数据集时,要记得在load文件中import含有神经网络模型类定义的文件夹
class Xigua(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(3,64,5)def forward(self,x):x=self.conv1(x)return xxigua1=Xigua()
torch.save(xigua1,"xigua1.pth")

model_load.py: 

import torch
import torchvision.models
from torch import nn
from model_save import *#加载重现模型 方式1
model1=torch.load("vgg16_method1.pth")
# print(model1)#方式2
# model2=torch.load("vgg16_method2.pth") #这里的出来的只是模型的参数
# print(model2)
model2=torchvision.models.vgg16(weights=None) #新建模型结构,不带参数
model2.load_state_dict(torch.load("vgg16_method2.pth")) #给模型加载参数
# print(model2)#会报错,因为这个文件里没有原模型类的信息,只有实例是不可行的
model_xigua1=torch.load("xigua1.pth")
print(model_xigua1)

关键字:二次开发信怎么写_2022年企业所得税税率_百度的网页地址_现在阳性最新情况

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: