其他分享
首页 > 其他分享> > 神奇的torch.gather()

神奇的torch.gather()

作者:互联网

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor 

Gathers values along an axis specified by dim.

例如 原本一个tensor a是:

a[0][0]a[0][1]
a[1][0]a[1][1]


index tensor是:

 

 

jk
mn

现在,b=torch.gather(input=a, dim=0, index=index)

因此,将第0维的数据替换成index的数据,则b是:

a[j][0]a[k][1]
a[m][0]a[n][1]

如果,b=torch.gather(input=a, dim=1, index=index)

那么,b将会是

a[0][j]a[0][k]
a[1][m]a[1][n]

总之,dim是多少,就将那一维所查询的位置换成index里面对应位置上的数

标签:index,torch,tensor,dim,gather,input,神奇
来源: https://blog.csdn.net/weixin_45828771/article/details/119382751