当前位置: 首页> 文旅> 美景 > Pytorch:加载断点(pth)权重参数

Pytorch:加载断点(pth)权重参数

时间:2025/7/11 7:44:18来源:https://blog.csdn.net/Ethan_Rich/article/details/141192218 浏览次数:0次

一、保存的模型参数及权重

#保存模型
torch.save(model_object,'resnet.pth')
#加载模型
model=torch.load('resnet.pth')

二、仅保存模型的权重


torch.save(my_resnet.state_dict(),"resnet.pth")resnet_model.load_state_dict(torch.load("resnet.pth"))

三、仅加载部分参数

resnet152=models.resnet152(pretrained=True)
pretrained_dict=resnet152.state_dict()model_dict=model.state_dict()#将pretrained_dict里不属于model_dict的键去除掉
pretrained_dict={k:v for k,v in pretrained_dict.items() if k in model_dict}#更新现有的model_dict
model_dict.update(pretrained_dict)
#加载真正需要的state_dict
model.load_state_dict(model_dict)

四、微调预训练模型 


resnet=torchvision.models.resnet152(pretrained=True)resnet.fc=torch.nn.Linear(2048,10)

关键字:Pytorch:加载断点(pth)权重参数

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

责任编辑: