其他分享
首页 > 其他分享> > 机器学习系列——关于torch.nn.DataParallel的测试

机器学习系列——关于torch.nn.DataParallel的测试

作者:互联网

0 前言

前几天把服务器上训练好的模型转移到Jeston开发板上跑测试,加载模型时报错:

no module named "model"

后来经过一番折腾,终于搞明白原因。是因为在服务器上跑训练时使用了torch.nn.DataParallel进行加速,所以保存后的模型在Jeston开发板上进行torch.load()时报错。
今天有时间了解了一下torch.nn.DataParallel这个模型,并进行简单测试。

1 torch.nn.DataParallel

参考https://zhuanlan.zhihu.com/p/102697821,讲得很细。

2 实际测试

下面摘出关键代码段,完整代码参见https://github.com/GaoZiqiang/fine_grained_Multiview_Detection/blob/master/utils.py

2.1 模型训练阶段

# 模型初始化
net = load_model(model_name='resnet50_pmg', pretrain=True, require_grad=True)
# 使用DataParallel加速训练
netp = torch.nn.DataParallel(net, device_ids=[0])

# 保存模型参数
torch.save(netp.module.state_dict(),"D:\gaoziqiang\model_netp_1.pth")

2.2 测试阶段

	model_path = "D:\gaoziqiang\model_netp_1.pth"

    ### 加载netp模型的原型
    net = load_model(model_name='resnet50_pmg', pretrain=True, require_grad=True)

    ### DataParallel化
    net = nn.DataParallel(net,device_ids=[0])
    ### 一定要先net.module
    net = net.module

    ### 加载state_dicts
    model_PMG_state_dicts = torch.load(model_path)
    ### 使用state_dicts实例化net
    net.load_state_dict(model_PMG_state_dicts)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)
    
    criterion = nn.CrossEntropyLoss()
	### 调用test主函数
    test(net,criterion,batch_size=2)

3 参考材料

标签:nn,torch,DataParallel,net,model,###
来源: https://blog.csdn.net/qq_33429968/article/details/117535719