其他分享
首页 > 其他分享> > Pytorch nn.DataParallel()的简单用法

Pytorch nn.DataParallel()的简单用法

作者:互联网

简单来说就是使用单机多卡进行训练。
一般来说我们看到的代码是这样的:

net = XXXNet()
net = nn.DataParallel(net)

这样就可以让模型在全部GPU上训练。

方法定义:

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)

参考

https://www.aiuai.cn/aifarm1340.html
https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch-nn/

标签:nn,ids,DataParallel,Pytorch,device,GPU,net
来源: https://blog.csdn.net/qq_40714949/article/details/115299128