其他分享
首页 > 其他分享> > pytorch多GPU计算

pytorch多GPU计算

作者:互联网

pytorch多GPU计算

如果正确安装了NVIDIA驱动,我们可以通过在命令行输入nvidia-smi命令来查看当前计算机上的全部GPU

定义一个模型:

import torch
net = torch.nn.Linear(10, 1).cuda()
net

output:

Linear(in_features=10, out_features=1, bias=True)

要想使用PyTorch进行多GPU计算,最简单的方法是直接用torch.nn.DataParallel将模型wrap一下即可:

net = torch.nn.DataParallel(net)
net

output:

DataParallel(
  (module): Linear(in_features=10, out_features=1, bias=True)
)

这时,默认所有存在的GPU都会被使用。

指定使用的GPU可以使用以下方式:

torch.nn.DataParallel(net, device_ids=[0, 1])

这表示只使用0、1号显卡

多GPU模型的保存与加载

torch.save(net.state_dict(), "./test_model.pt")

加载模型前我们一般要先进行一下模型定义,此时的new_net并没有使用多GPU:

new_net = torch.nn.Linear(10, 1)
new_net.load_state_dict(torch.load("./test_model.pt"))

报错

RuntimeError: Error(s) in loading state_dict for Linear:
    Missing key(s) in state_dict: "weight", "bias". 
    Unexpected key(s) in state_dict: "module.weight", "module.bias". 

事实上DataParallel也是一个nn.Module,只是这个类其中有一个module就是传入的实际模型。因此当我们调用DataParallel后,模型结构变了。所以直接加载肯定会报错的,因为模型结构对不上。

所以正确的方法是保存的时候只保存net.module:

torch.save(net.module.state_dict(), "./test_model.pt")
new_net.load_state_dict(torch.load("./test_model.pt")) # 加载成功

或者先将new_net用DataParallel包括以下再用上面报错的方法进行模型加载:

torch.save(net.state_dict(), "./test_model.pt")
new_net = torch.nn.Linear(10, 1)
new_net = torch.nn.DataParallel(new_net)
new_net.load_state_dict(torch.load("./test_model.pt")) # 加载成功

推荐用第一种方法,因为可以按照普通的加载方法进行正确加载

标签:torch,state,DataParallel,pytorch,dict,计算,new,GPU,net
来源: https://blog.csdn.net/qq_42255269/article/details/112107920