pytorch多GPU计算
作者:互联网
pytorch多GPU计算
如果正确安装了NVIDIA驱动,我们可以通过在命令行输入nvidia-smi命令来查看当前计算机上的全部GPU
定义一个模型:
import torch
net = torch.nn.Linear(10, 1).cuda()
net
output:
Linear(in_features=10, out_features=1, bias=True)
要想使用PyTorch进行多GPU计算,最简单的方法是直接用torch.nn.DataParallel将模型wrap一下即可:
net = torch.nn.DataParallel(net)
net
output:
DataParallel(
(module): Linear(in_features=10, out_features=1, bias=True)
)
这时,默认所有存在的GPU都会被使用。
指定使用的GPU可以使用以下方式:
torch.nn.DataParallel(net, device_ids=[0, 1])
这表示只使用0、1号显卡
多GPU模型的保存与加载
torch.save(net.state_dict(), "./test_model.pt")
加载模型前我们一般要先进行一下模型定义,此时的new_net并没有使用多GPU:
new_net = torch.nn.Linear(10, 1)
new_net.load_state_dict(torch.load("./test_model.pt"))
报错
RuntimeError: Error(s) in loading state_dict for Linear:
Missing key(s) in state_dict: "weight", "bias".
Unexpected key(s) in state_dict: "module.weight", "module.bias".
事实上DataParallel也是一个nn.Module,只是这个类其中有一个module就是传入的实际模型。因此当我们调用DataParallel后,模型结构变了。所以直接加载肯定会报错的,因为模型结构对不上。
所以正确的方法是保存的时候只保存net.module:
torch.save(net.module.state_dict(), "./test_model.pt")
new_net.load_state_dict(torch.load("./test_model.pt")) # 加载成功
或者先将new_net用DataParallel包括以下再用上面报错的方法进行模型加载:
torch.save(net.state_dict(), "./test_model.pt")
new_net = torch.nn.Linear(10, 1)
new_net = torch.nn.DataParallel(new_net)
new_net.load_state_dict(torch.load("./test_model.pt")) # 加载成功
推荐用第一种方法,因为可以按照普通的加载方法进行正确加载
标签:torch,state,DataParallel,pytorch,dict,计算,new,GPU,net 来源: https://blog.csdn.net/qq_42255269/article/details/112107920