机器学习系列——关于torch.nn.DataParallel的测试
作者:互联网
0 前言
前几天把服务器上训练好的模型转移到Jeston开发板上跑测试,加载模型时报错:
no module named "model"
后来经过一番折腾,终于搞明白原因。是因为在服务器上跑训练时使用了torch.nn.DataParallel进行加速,所以保存后的模型在Jeston开发板上进行torch.load()时报错。
今天有时间了解了一下torch.nn.DataParallel这个模型,并进行简单测试。
1 torch.nn.DataParallel
参考https://zhuanlan.zhihu.com/p/102697821,讲得很细。
2 实际测试
下面摘出关键代码段,完整代码参见https://github.com/GaoZiqiang/fine_grained_Multiview_Detection/blob/master/utils.py
2.1 模型训练阶段
# 模型初始化
net = load_model(model_name='resnet50_pmg', pretrain=True, require_grad=True)
# 使用DataParallel加速训练
netp = torch.nn.DataParallel(net, device_ids=[0])
# 保存模型参数
torch.save(netp.module.state_dict(),"D:\gaoziqiang\model_netp_1.pth")
2.2 测试阶段
model_path = "D:\gaoziqiang\model_netp_1.pth"
### 加载netp模型的原型
net = load_model(model_name='resnet50_pmg', pretrain=True, require_grad=True)
### DataParallel化
net = nn.DataParallel(net,device_ids=[0])
### 一定要先net.module
net = net.module
### 加载state_dicts
model_PMG_state_dicts = torch.load(model_path)
### 使用state_dicts实例化net
net.load_state_dict(model_PMG_state_dicts)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
criterion = nn.CrossEntropyLoss()
### 调用test主函数
test(net,criterion,batch_size=2)
3 参考材料
标签:nn,torch,DataParallel,net,model,### 来源: https://blog.csdn.net/qq_33429968/article/details/117535719