其他分享
首页 > 其他分享> > pytorch 保存训练好的模型

pytorch 保存训练好的模型

作者:互联网

1 保存和加载整个模型 

torch.save(model_object, 'model.pth') 
model = torch.load('model.pth')

2 仅保存和加载模型参数

torch.save(model_obj.state_dict(), 'params.pth')  
model_obj.load_state_dict(torch.load('params.pth'))  

3 选择保存网络中的一部分参数或者额外保存其余的参数

torch.save({'state_dict': net.state_dict(), 'linear1':net.linear1.state_dict(),
            'optimizer': optimizer.state_dict(),'num_epoch':num_epochs },
            'detail.pth')
model = torch.load('detail.pth')
net = DNN(num_input,num_hidden1,num_hidden2,num_output)
net.load_state_dict(model['state_dict'])

 

参考:

[日常] PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型

 

标签:训练,pth,模型,torch,state,pytorch,num,dict,model
来源: https://www.cnblogs.com/BlairGrowing/p/15981277.html