其他分享
首页 > 其他分享> > PyTorch中的矩阵乘法

PyTorch中的矩阵乘法

作者:互联网

1. 二维矩阵乘法 [公式]

[公式] , 其中 [公式][公式], 输出 [公式]的维度是[公式]。该函数[公式]一般只用来计算两个二维矩阵的矩阵乘法,而且不支持broadcast操作。

 

2. 三维带Batch矩阵乘法 [公式]

由于神经网络训练一般采用mini-batch,经常输入的是三维带batch矩阵,所以提供 [公式],其中 [公式][公式], 输出 [公式]的维度是 [公式]。该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。

3. "混合"矩阵乘法 [公式]

[公式] 支持broadcast操作,使用起来比较复杂,建议参考pytorch官方文档

 

 

 特别 ,针对多维数据 [公式]乘法,我们可以认为该 [公式]乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别是[公式][公式],那么我们可以认为 [公式] 乘法首先是进行后两位矩阵乘法得到[公式] ,然后分析两个参数的batch size分别是 [公式] 和 [公式] , 可以广播成为 [公式], 因此最终输出的维度是 [公式]

4. 矩阵逐元素(Element-wise)乘法 [公式]

[公式],其中 [公式] 乘数可以是标量也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast的即可,即该操作是支持broadcast操作的。

[公式] 是矩阵: 只要 [公式] 与 [公式] 的维度可以满足broadcast条件,就可以进行逐元素乘法操作,例如:

1 import torch
2 A = torch.randn(2,3,4)
3 B = torch.randn(3, 4)
4 print (torch.mul(A,b).shape) # 输出 torch.size([2,3,4)

5. 两个乘法操作符@和[公式] 

简单来说, @ 操作符可以执行矩阵乘法操作,类似 [公式] ; 而 [公式] 乘法操作可以执行逐元素矩阵乘法,使用方法类似 [公式]

 1 import torch
 2 
 3 x=torch.ones(3,2)
 4 print(x)
 5 
 6 y=torch.ones(3,2)+2
 7 print(y)
 8 
 9 z=torch.ones(2,1)
10 print(z)
11 
12 print(x*y@z)

 

 

 

 

参考:随笔1: PyTorch中矩阵乘法总结 - 知乎 (zhihu.com)

标签:torch,矩阵,broadcast,print,PyTorch,维度,乘法
来源: https://www.cnblogs.com/fan-faith/p/16140857.html