一、现有网络模型的修改和使用
以torchvision里的models模块里的vgg16网络模型为例:
加载现有网络模型:
vgg16_false=torchvision.models.vgg16(weights=None) #表示不加载预训练权重
vgg16_true=torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT) #表示加载预训练权重
修改现有网络模型:
添加新层:
vgg16_true.classifier.add_module(name="add_linear",module=nn.Linear(1000,10))
#在classifier层级下,添加新模块(层),参数是(模块名,模块)
修改现有的层:
vgg_false.classifier[6]=nn.Linear(4096,10)
#修改classifier的第六模块(层),直接赋予新定义
import torchvision
from torch import nn
from torchvision.models import VGG16_Weights#加载现有网络模型
vgg16_false=torchvision.models.vgg16(weights=None)
vgg16_true=torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)
print("ok")
# pretrained=False改为weight=None pretrained=True改为weights=VGG16_Weights.DEFAULT
print(vgg16_true)#vgg_16最终分类是1000,而CIFAR10最终分类是10,如何更改?1、把out_features=1000改成=10,2、新增一层线性层
dataset=torchvision.datasets.CIFAR10(root="../dataset",transform=torchvision.transforms.ToTensor(),download=True)#修改现有网络模型
vgg16_false.classifier[6]=nn.Linear(4096,10) #法一:直接修改原来那个线性层,重新赋值即可
print(vgg16_false)vgg16_true.classifier.add_module(name="add_linear",module=nn.Linear(1000,10)) #法二:新增线性层模块(module)
print((vgg16_true))
二、网络模型的保存与读取
.pth是pytorch中常用的模型权重保存格式
保存方式1:模型结构+参数
torch.save(模型实例,文件保存路径)
例子中,vgg16是VGG16模型类的一个实例对象,
torchvision.models.vgg16
是一个函数,它返回一个已初始化的 VGG16 模型对象保存方式2:仅参数(适合大模型)
torch.save(模型实例的所有参数vgg16.state_dict(),文件保存路径)
state_dict()
:返回模型的所有参数
加载方式1:
model=torch.load(文件路径)
加载方式2:
model=torchvision.models.vgg16(weights=None) #新建不带参数的模型结构
model.load_state_dict(torch.load(文件路径)) #加载模型参数
注意:自定义的模型要import模型类定义的文件夹,实例要能和类联系得上才行
model_save.py:
import torch
import torchvision
from torch import nnvgg16=torchvision.models.vgg16(weights=None)
#保存方式1:模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")#保存方式2:模型参数(官方推荐,对于大模型很友好)
torch.save(vgg16.state_dict(),"vgg16_method2.pth") #把状态(参数)保存成字典#注意:用自定义的数据集时,要记得在load文件中import含有神经网络模型类定义的文件夹
class Xigua(nn.Module):def __init__(self):super().__init__()self.conv1=nn.Conv2d(3,64,5)def forward(self,x):x=self.conv1(x)return xxigua1=Xigua()
torch.save(xigua1,"xigua1.pth")
model_load.py:
import torch
import torchvision.models
from torch import nn
from model_save import *#加载重现模型 方式1
model1=torch.load("vgg16_method1.pth")
# print(model1)#方式2
# model2=torch.load("vgg16_method2.pth") #这里的出来的只是模型的参数
# print(model2)
model2=torchvision.models.vgg16(weights=None) #新建模型结构,不带参数
model2.load_state_dict(torch.load("vgg16_method2.pth")) #给模型加载参数
# print(model2)#会报错,因为这个文件里没有原模型类的信息,只有实例是不可行的
model_xigua1=torch.load("xigua1.pth")
print(model_xigua1)