torch.cat() 维度解析
作者:互联网
# Build a random tensor of shape (2, 3, 4) and print it; it serves as the
# input for the concatenation examples that follow.
import torch

a = torch.randn(2, 3, 4)  # equivalent to torch.randn((2, 3, 4))
print(a)
res:
tensor([[[ 0.2615, 0.9965, 0.5935, -2.4657],
[-2.0211, 0.5055, 0.3128, -0.8843],
[-1.3269, -1.0438, 0.3159, -0.8572]],
[[ 0.7990, 0.2023, -0.1174, -0.8619],
[ 0.5947, 0.9148, 0.6587, 1.6023],
[ 1.0809, -1.2143, 0.8764, 0.7159]]])
dim=0：沿第 0 维拼接，结果 b.shape = torch.Size([4, 3, 4])
# Concatenate a (2, 3, 4) tensor with itself along dim 0.
# The sizes along dim 0 add up (2 + 2) while dims 1 and 2 stay the same,
# so the result b has shape (4, 3, 4): the first two slices are `a`,
# and the last two slices repeat `a`.
import torch

a = torch.randn(2, 3, 4)
pair = (a, a)
b = torch.cat(pair, dim=0)
print(b)
res:
tensor([[[-0.6508, 0.1268, -0.8134, -0.7238],
[-0.0616, -0.7403, 0.3288, -0.8408],
[ 0.3305, -2.1410, -1.7286, 0.7594]],
[[ 1.5005, 0.3792, -0.8897, 0.3702],
[ 1.1504, -0.1261, -0.3419, -0.6803],
[ 1.3511, 0.5674, 0.6122, 1.0454]],
[[-0.6508, 0.1268, -0.8134, -0.7238],
[-0.0616, -0.7403, 0.3288, -0.8408],
[ 0.3305, -2.1410, -1.7286, 0.7594]],
[[ 1.5005, 0.3792, -0.8897, 0.3702],
[ 1.1504, -0.1261, -0.3419, -0.6803],
[ 1.3511, 0.5674, 0.6122, 1.0454]]])
dim=1：沿第 1 维拼接，结果 c.shape = torch.Size([2, 6, 4])
# Concatenate a (2, 3, 4) tensor with itself along dim 1.
# Dims 0 and 2 stay the same while the sizes along dim 1 add up (3 + 3),
# so the result c has shape (2, 6, 4): within each of the 2 outer slices,
# the 3 rows of `a` appear twice in a row.
import torch

a = torch.randn((2, 3, 4))
c = torch.cat((a, a), dim=1)
print(c)
res:
tensor([[[ 0.7432, -1.4758, 0.5132, -0.6230],
[ 0.5078, -2.2970, -0.1246, 0.4064],
[-0.1692, -1.6422, -0.1521, -1.3583],
[ 0.7432, -1.4758, 0.5132, -0.6230],
[ 0.5078, -2.2970, -0.1246, 0.4064],
[-0.1692, -1.6422, -0.1521, -1.3583]],
[[ 0.6635, -0.7000, 0.0392, -0.9496],
[-0.2973, 0.8815, 1.1791, -0.8074],
[ 0.9554, 0.5348, -0.6834, 0.5662],
[ 0.6635, -0.7000, 0.0392, -0.9496],
[-0.2973, 0.8815, 1.1791, -0.8074],
[ 0.9554, 0.5348, -0.6834, 0.5662]]])
标签:dim,randn,tensor,torch,cat,print,维度 来源: https://blog.csdn.net/weixin_40823740/article/details/115374854