其他分享
首页 > 其他分享> > pytorch 数据类型

pytorch 数据类型

作者:互联网

pytorch中的数据类型

 

 

import torch

a=torch.randn(2,3)
b=a.type()
print(b)

#检验是否是该数据类型
print(isinstance(a,torch.FloatTensor))


print(isinstance(a,torch.cuda.FloatTensor))
a=a.cuda()
print(isinstance(a,torch.cuda.FloatTensor))


基本数据类型的生成

#生成一个Tensor,数值为1.1
a=torch.Tensor([1.1])
print(a)

#生成一个二维的Tensor,数值为1.1,2.2
b=torch.Tensor([1.1,2.2])
print(b)

#生成一个一维的Tensor,Tensor的值由random初始化
c=torch.FloatTensor(1)
print(c)

#生成er个一维的Tensor,Tensor的值random初始化
d=torch.FloatTensor(2)
print(d)

#由np生成一个Tensor,二维数值为1
e=np.ones(2)
print(e)
f=torch.from_numpy(e)
print(f)

 

dim、size与shape的区别

a=torch.ones(4,3)

print(a)
print(a.dim())
print(a.size())
print(a.shape)


>>>tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
2
torch.Size([4, 3])
torch.Size([4, 3])

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

1

标签:Tensor,FloatTensor,torch,数据类型,pytorch,print,1.1
来源: https://www.cnblogs.com/Manuel/p/10821026.html