其他分享
首页 > 其他分享> > torch保存加载模型

torch保存加载模型

作者:互联网

目录

三个核心函数

torch.save() 
torch.load()
torch.nn.Module.load_state_dict()

状态字典定义

状态字典本质上就是普通的python字典。

只保存/加载模型参数(推荐做法)

# 保存模型参数
torch.save(model.state_dict(), PATH)  
# 加载模型参数并用于推理
model = MyModel()
model.load_static_dict(torch.load(PATH))
model.eval()

保存/加载整个模型

# 保存整个模型
torch.save(model, PATH)
# 加载整个模型
model = torch.load(PATH)
model.eval()

断点训练checkpoint使用

# 保存断点状态,保存的文件后缀一般是.tar。
torch.save({
  'epoch': epoch,
  'model_state_dict': model.state_dict(),
  'loss': loss,
  ...
}, PATH)

# 加载断点
model = MyModel()
optimizer = MyOptimizer()

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.train()
# model.eval() # 恢复断点之后直接推理也是可以的

同一个文件中保存多个模型

# 其实本质上跟checkpoint的使用是一样的
torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)
# 加载多个模型,本质上跟checkpoint也是一样的,保存文件后缀名也是.tar
modelA = MyModel()
modelB = MyModel()
optimizerA = MyOptimizer()
optimizerB = MyOptimizer()

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

用一个模型的参数来初始化另一个不同模型

# 保存模型参数
torch.save(modelA.state_dict(), PATH)
# 加载模型参数
modelB = MyModel()
modelB.load_state_dict(torch.load(PATH), strict=False)

不同设备保存/加载模型

标签:load,模型,torch,checkpoint,state,dict,model,加载
来源: https://www.cnblogs.com/chkplusplus/p/15987252.html