pytorch 数据类型
作者:互联网
pytorch中的数据类型
import torch a=torch.randn(2,3) b=a.type() print(b) #检验是否是该数据类型 print(isinstance(a,torch.FloatTensor)) print(isinstance(a,torch.cuda.FloatTensor)) a=a.cuda() print(isinstance(a,torch.cuda.FloatTensor))
基本数据类型的生成
#生成一个Tensor,数值为1.1 a=torch.Tensor([1.1]) print(a) #生成一个二维的Tensor,数值为1.1,2.2 b=torch.Tensor([1.1,2.2]) print(b) #生成一个一维的Tensor,Tensor的值由random初始化 c=torch.FloatTensor(1) print(c) #生成er个一维的Tensor,Tensor的值random初始化 d=torch.FloatTensor(2) print(d) #由np生成一个Tensor,二维数值为1 e=np.ones(2) print(e) f=torch.from_numpy(e) print(f)
dim、size与shape的区别
a=torch.ones(4,3) print(a) print(a.dim()) print(a.size()) print(a.shape) >>>tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) 2 torch.Size([4, 3]) torch.Size([4, 3])
1
标签:Tensor,FloatTensor,torch,数据类型,pytorch,print,1.1 来源: https://www.cnblogs.com/Manuel/p/10821026.html