其他分享
首页 > 其他分享> > Pytorch中torch.gather和torch.scatter函数理解

Pytorch中torch.gather和torch.scatter函数理解

作者:互联网

torch.gather()

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

参数解释:

示例1:

t = torch.tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1,  1],
        [ 4,  3]])

解释:

gather的意思是聚集和取,即从input这个张量中取元素,而index则对应所取元素的下标。如果dim=0,那么index中的数值表示行坐标,如果dim=1,那么index中的数值表示列坐标。另外,index的shape和output的shape应该要一致。

以上述示例来说就是:index的第一行对应输出的第一行,其元素[0,0]就是从t中的第一行的下标为0的位置取其元素

示例2:

t = torch.tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
                  [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
                  [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
torch.gather(t, 0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]))
tensor([[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
        [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])

torch.scatter()

torch.scatter_(input, dim, index, src) → Tensor

参数解释:

示例1:

x = torch.rand(2, 5)
x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

解释:

scatter可以理解为gather的反操作,即用src中的元素去替换input中的元素,而index中的数值则对应input元素的下标。如果dim=0,那么index中的数值表示横坐标,如果dim=1,那么index中的数值表示纵坐标。另外,output的shape和input的shape是一致的。

src = torch.arange(1, 11).reshape((2, 5))
src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
------------------------------------------------
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
------------------------------------------------
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])
------------------------------------------------
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
------------------------------------------------
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

参考链接:

https://zhuanlan.zhihu.com/p/187401278
https://www.cnblogs.com/dogecheng/p/11938009.html
https://wmathor.com/index.php/archives/1457/

标签:index,gather,tensor,torch,Pytorch,0.0000,scatter,2.0000
来源: https://blog.csdn.net/weixin_42838061/article/details/117870261