torch.bmm()解读
作者:互联网
函数作用
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m) 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,输出维度 (b,h,m)
代码示例
>>> c=torch.randn((2,5))
>>> print(c)
tensor([[ 1.0559, -0.3533, 0.5194, 0.9526, -0.2483],
[-0.1293, 0.4809, -0.5268, -0.3673, 0.0666]])
>>> d=torch.reshape(c,(5,2))
>>> print(d)
tensor([[ 1.0559, -0.3533],
[ 0.5194, 0.9526],
[-0.2483, -0.1293],
[ 0.4809, -0.5268],
[-0.3673, 0.0666]])
>>> e=torch.bmm(c,d)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
当tensor维度为2时会报错!
>>> ccc=torch.randn((1,2,2,5))
>>> ddd=torch.randn((1,2,5,2))
>>> e=torch.bmm(ccc,ddd)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 1: expected 3D tensor, got 4D at /opt/conda/conda-bld/pytorch_1535490206202/work/aten/src/TH/generic/THTensorMath.cpp:2304
维度为4时也会报错!
>>> cc=torch.randn((2,2,5))
>>>print(cc)
tensor([[[ 1.4873, -0.7482, -0.6734, -0.9682, 1.2869],
[ 0.0550, -0.4461, -0.1102, -0.0797, -0.8349]],
[[-0.6872, 1.1920, -0.9732, 0.4580, 0.7901],
[ 0.3035, 0.2022, 0.8815, 0.9982, -1.1892]]])
>>>dd=torch.reshape(cc,(2,5,2))
>>> print(dd)
tensor([[[ 1.4873, -0.7482],
[-0.6734, -0.9682],
[ 1.2869, 0.0550],
[-0.4461, -0.1102],
[-0.0797, -0.8349]],
[[-0.6872, 1.1920],
[-0.9732, 0.4580],
[ 0.7901, 0.3035],
[ 0.2022, 0.8815],
[ 0.9982, -1.1892]]])
>>>e=torch.bmm(cc,dd)
>>> print(e)
tensor([[[ 2.1787, -1.3931],
[ 0.3425, 1.0906]],
[[-0.5754, -1.1045],
[-0.6941, 3.0161]]])
>>> e.size()
torch.Size([2, 2, 2])
正确!
参考:https://blog.csdn.net/qq_40178291/article/details/100302375标签:randn,tensor,cc,torch,bmm,解读,print 来源: https://www.cnblogs.com/chentiao/p/16316266.html