torch.cat() dimensions explained

Author: Internet

import torch

a = torch.randn((2, 3, 4))
print(a)

Output:
tensor([[[ 0.2615,  0.9965,  0.5935, -2.4657],
         [-2.0211,  0.5055,  0.3128, -0.8843],
         [-1.3269, -1.0438,  0.3159, -0.8572]],

        [[ 0.7990,  0.2023, -0.1174, -0.8619],
         [ 0.5947,  0.9148,  0.6587,  1.6023],
         [ 1.0809, -1.2143,  0.8764,  0.7159]]])
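To read the printout above: dim 0 indexes the 2 outer blocks, dim 1 indexes the 3 rows in each block, and dim 2 indexes the 4 values in each row. The sketch below checks this by indexing. It assumes a standard PyTorch install, and only the shapes matter here, not the random values.

```python
import torch

a = torch.randn((2, 3, 4))

# dim 0 -> 2 outer blocks, dim 1 -> 3 rows per block, dim 2 -> 4 values per row
print(a.shape)        # torch.Size([2, 3, 4])
print(a[0].shape)     # torch.Size([3, 4]): the first outer block
print(a[0, 1].shape)  # torch.Size([4]): the second row of the first block
print(a.dim())        # 3
```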

dim=0, result shape (4, 3, 4): the two 3×4 blocks of the second a are appended after the two blocks of the first.

import torch

a = torch.randn((2, 3, 4))
b = torch.cat((a, a), dim=0)
print(b)

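Output (a is re-randomised in this snippet, so the values differ from the printout above; the first two 3×4 blocks and the last two are identical because both halves are copies of a):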
tensor([[[-0.6508,  0.1268, -0.8134, -0.7238],
         [-0.0616, -0.7403,  0.3288, -0.8408],
         [ 0.3305, -2.1410, -1.7286,  0.7594]],

        [[ 1.5005,  0.3792, -0.8897,  0.3702],
         [ 1.1504, -0.1261, -0.3419, -0.6803],
         [ 1.3511,  0.5674,  0.6122,  1.0454]],

        [[-0.6508,  0.1268, -0.8134, -0.7238],
         [-0.0616, -0.7403,  0.3288, -0.8408],
         [ 0.3305, -2.1410, -1.7286,  0.7594]],

        [[ 1.5005,  0.3792, -0.8897,  0.3702],
         [ 1.1504, -0.1261, -0.3419, -0.6803],
         [ 1.3511,  0.5674,  0.6122,  1.0454]]])
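The shape results in this article follow one rule: all inputs must match in every dimension except dim, and the output size along dim is the sum of the input sizes there. Below is a minimal pure-Python sketch of that rule. `cat_shape` is a hypothetical helper, not part of PyTorch. It raises its own ValueError/IndexError rather than PyTorch's errors and ignores edge cases such as empty tensors.

```python
def cat_shape(shapes, dim):
    """Hypothetical helper: predicted output shape of torch.cat for these input shapes."""
    ndim = len(shapes[0])
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-D inputs")
    dim %= ndim  # normalise negative dims, e.g. -1 -> ndim - 1
    for s in shapes:
        if len(s) != ndim:
            raise ValueError("all inputs must have the same number of dimensions")
        for i in range(ndim):
            # every dimension except `dim` must agree with the first input
            if i != dim and s[i] != shapes[0][i]:
                raise ValueError(f"size mismatch at dim {i}: {s[i]} vs {shapes[0][i]}")
    out = list(shapes[0])
    out[dim] = sum(s[dim] for s in shapes)  # sizes along `dim` add up
    return tuple(out)


print(cat_shape([(2, 3, 4), (2, 3, 4)], dim=0))  # (4, 3, 4)
print(cat_shape([(2, 3, 4), (2, 3, 4)], dim=1))  # (2, 6, 4)
print(cat_shape([(2, 3, 4), (2, 3, 4)], dim=2))  # (2, 3, 8)
```

Inputs only need to agree outside dim, so for example (2, 3, 4) and (2, 5, 4) can be joined along dim=1 to give (2, 8, 4), but not along dim=0.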

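dim=1, result shape (2, 6, 4): inside each outer block, the 3 rows of the second a are appended after the 3 rows of the first.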

import torch

a = torch.randn((2, 3, 4))
c = torch.cat((a, a), dim=1)
print(c)

Output (within each outer block, rows 4-6 repeat rows 1-3 because both halves are copies of a):
tensor([[[ 0.7432, -1.4758,  0.5132, -0.6230],
         [ 0.5078, -2.2970, -0.1246,  0.4064],
         [-0.1692, -1.6422, -0.1521, -1.3583],
         [ 0.7432, -1.4758,  0.5132, -0.6230],
         [ 0.5078, -2.2970, -0.1246,  0.4064],
         [-0.1692, -1.6422, -0.1521, -1.3583]],

        [[ 0.6635, -0.7000,  0.0392, -0.9496],
         [-0.2973,  0.8815,  1.1791, -0.8074],
         [ 0.9554,  0.5348, -0.6834,  0.5662],
         [ 0.6635, -0.7000,  0.0392, -0.9496],
         [-0.2973,  0.8815,  1.1791, -0.8074],
         [ 0.9554,  0.5348, -0.6834,  0.5662]]])
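Instead of comparing printed numbers by eye, the layout can be checked with slicing and `torch.equal`. This sketch also covers dim=2, written here as dim=-1 (the last dimension), which the examples above do not show. By the same rule, each row grows from 4 to 8 values, giving shape (2, 3, 8).

```python
import torch

a = torch.randn((2, 3, 4))
c = torch.cat((a, a), dim=1)

# Along dim=1, rows 0-2 of each block come from the first input, rows 3-5 from the second.
print(c.shape)                     # torch.Size([2, 6, 4])
print(torch.equal(c[:, :3], a))    # True
print(torch.equal(c[:, 3:], a))    # True

# dim=-1 is the same as dim=2 for a 3-D tensor: values are joined within each row.
d = torch.cat((a, a), dim=-1)
print(d.shape)                     # torch.Size([2, 3, 8])
print(torch.equal(d[..., :4], a))  # True
```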

Source: https://blog.csdn.net/weixin_40823740/article/details/115374854