其他分享
首页 > 其他分享> > pytorch使用多GPU

pytorch使用多GPU

作者:互联网

直接修改dict的key当然也是可以的,不会影响模型。

但是逻辑上,事实上DataParallel也是一个Pytorch的nn.Module,只是这个类其中有一个module的变量用来保存传入的实际模型。

nn.DataParallel(m)

这句返回的已经不是原始的m了,而是一个DataParallel,原始的m保存在DataParallel的module变量里面。

 

所以,逻辑上有两个方法:

  1. 保存的时候直接取出原始的m:
torch.save(m.module.state_dict(), path)

2. 或者载入的时候用一个DataParallel载入,再取出原始模型:

m=nn.DataParallel(Resnet18(), device_ids=[0,1,2])
m.load_state_dict(torch.load(path))
m=m.module

这样逻辑上更好看一点。



标签:load,原始,nn,module,DataParallel,pytorch,dict,使用,GPU
来源: https://www.cnblogs.com/ccfco/p/15184213.html