其他分享
首页 > 其他分享> > torch.gather

torch.gather

作者:互联网

函数定义

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Gathers values along an axis specified by dim.

对于一个3-D的张量,输出按照以下公式被指定为:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

函数参数

函数参数说明

例子

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

标签:index,gather,tensor,dim,torch,input,out
来源: https://www.cnblogs.com/zjuhaohaoxuexi/p/15595584.html