其他分享
首页 > 其他分享> > pytorch中网络参数的默认精度

pytorch中网络参数的默认精度

作者:互联网

pytorch默认使用单精度float32训练模型,其主要原因为:使用float16训练模型,模型效果会有损失,而使用double(float64)会有2倍的内存压力,且不会带来太多的精度提升,因此默认使用单精度float32训练模型。

 

由于输入类型不一致导致报错:

PyTorch:expected scalar type Float but found Double

表明代码中网络参数类型不统一。

pytorch如何更改默认单精度float32训练模型,而改为torch.float64对模型进行训练呢?

解决办法:把模型的权重参数数据类型和输入数据类型全部设置为torch.float64。

使用torch.set_default_dtype(torch.float64)把模型参数转化为float64,或使用net = net.double()

输入类型使用tensor.type(torch.float64)将输入数据类型转换为torch.float64。

 

 

标签:单精度,float64,模型,torch,默认,pytorch,float32,精度
来源: https://www.cnblogs.com/jiangkejie/p/16415587.html