Pytorch nn.DataParallel()的简单用法
作者:互联网
简单来说就是使用单机多卡进行训练。
一般来说我们看到的代码是这样的:
net = XXXNet()
net = nn.DataParallel(net)
这样就可以让模型在全部GPU上训练。
方法定义:
class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
- module:需要进行并行的模型
- device_ids:并行所用的GPU。可以是int列表也可以是device对象,默认不写就是使用全部GPU
- output_device:输出所用的GPU。可以是GPU id或device对象,默认不写就是第一张(device_ids[0])
参考
https://www.aiuai.cn/aifarm1340.html
https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/
标签:nn,ids,DataParallel,Pytorch,device,GPU,net 来源: https://blog.csdn.net/qq_40714949/article/details/115299128