pytorch中网络参数的默认精度
作者:互联网
pytorch默认使用单精度float32训练模型,其主要原因为:使用float16训练模型,模型效果会有损失,而使用double(float64)会有2倍的内存压力,且不会带来太多的精度提升,因此默认使用单精度float32训练模型。
由于输入类型不一致导致报错:
PyTorch:expected scalar type Float but found Double
表明代码中网络参数类型不统一。
pytorch如何更改默认单精度float32训练模型,而改为torch.float64对模型进行训练呢?
解决办法:把模型的权重参数数据类型和输入数据类型全部设置为torch.float64。
使用torch.set_default_dtype(torch.float64)把模型参数转化为float64,或使用net = net.double()
输入类型使用tensor.type(torch.float64)将输入数据类型转换为torch.float64。
标签:单精度,float64,模型,torch,默认,pytorch,float32,精度 来源: https://www.cnblogs.com/jiangkejie/p/16415587.html