einsum函数介绍-张量常用操作
作者:互联网
pytorch文档说明:\(torch.einsum(\)\(*equation*\)$, $$operands*$$)$ 使用基于爱因斯坦求和约定的符号,将输入operands的元素沿指定的维数求和。einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们。主要是省略了求和号,总体思路是在箭头左边用一些下标标记输入operands的每个维度,并在箭头右边定义哪些下标是输出的一部分。通过将operands元素与下标不属于输出的维度的乘积求和来计算输出。其方便之处在于可以直接通过求和公式写出运算代码。**
两个基本概念,自由索引(Free indices)和求和索引(Summation indices):
- 自由索引,出现在箭头右边的索引
- 求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,
单操作数
获取对角线元素diagonal
einsum 可以不做求和。举个例子,获取二维方阵的对角线元素,结果放入一维向量。
\[A_i = B_{ii} \]上面,A 是一维向量,B 是二维方阵。使用 einsum 记法,可以写作 ii->i
torch.einsum('ii->i', torch.randn(4, 4))
# 以下操作互相等价
a = torch.randn(4,4)
c = torch.einsum('ii->i', a)
c = torch.diagonal(a, 0)
迹trace
求解矩阵的迹(trace),即对角线元素的和。
\[t = \Sigma_{i=1}^{n} A_{ii} \]t 是常量,A 是二维方阵。按照前面的做法,省略 ΣΣ,左右两边对调,省去矩阵和 t,剩下的就是ii->
或省略箭头ii
torch.einsum('ii', torch.randn(4, 4))
矩阵转置
\[A_{ij} = B_{ji} \]A 和 B 都是二维方阵。einsum 可以表达为 ij->ji
。
torch.einsum('ij -> ji',a)
pytorch 中,还支持省略前面的维度。比如,只转置最后两个维度,可以表达为 ...ij->...ji
。下面展示了一个含有四个二维矩阵的三维矩阵,转置三维矩阵中的每个二维矩阵。
A = torch.randn(2, 3, 4, 5)
torch.einsum('...ij->...ji', A).shape
# torch.Size([2, 3, 5, 4])
# 等价操作
A.permute(0,1,3,2)
A.transpose(2,3)
求和
\[b=\sum_{i} \sum_{j} A_{i j}=A_{i j} \]a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->', [a])
tensor(15.)
列求和:
\[b_{j}=\sum_{i} A_{i j}=A_{i j} \]a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->j', [a])
tensor([ 3., 5., 7.])
# 等价操作
torch.sum(a, 0) # (dim参数0) means the dimension or dimensions to reduce.
双操作数
矩阵乘法
\[A_{ij} = \Sigma_{k=1}^{n} B_{ik} C_{kj} \]第一个学习的 einsum 表达式是,ik,kj->ij
。前面提到过,爱因斯坦求和记法可以理解为懒人求和记法。将上述公式中的 ΣΣ 去掉,并且将左右两边对调一下,省去矩阵之后,剩下的就是 ik,kj->ij
了。
torch.einsum('ik,kj->ij', a, b)
# 可用两个矩阵测试以下矩阵乘法操作互相等价
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b
矩阵-向量相乘
\[c_{i}=\sum_{k} A_{i k} b_{k}=A_{i k} b_{k} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)
torch.einsum('ik,k->i', [a, b])
tensor([ 5., 14.])
批量矩阵乘 batch matrix multiplication
\[C_{bik}=\sum_{k} A_{bij} B_{bjk}=A_{bij} B_{bjk} \]>>> As = torch.randn(3,2,5)
>>> Bs = torch.randn(3,5,4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
[-1.6706, -0.8097, -0.8025, -2.1183]],
[[ 4.2239, 0.3107, -0.5756, -0.2354],
[-1.4558, -0.3460, 1.5087, -0.8530]],
[[ 2.8153, 1.8787, -4.3839, -1.2112],
[ 0.3728, -2.1131, 0.0921, 0.8305]]])
# 等价操作
torch.bmm(As, Bs)
向量内积 dot
\[c=\sum_{i} a_{i} b_{i}=a_{i} b_{i} \]a = torch.arange(3)
b = torch.arange(3,6) # [3, 4, 5]
torch.einsum('i,i->', [a, b])
# tensor(14.)
# 等价操作
torch.dot(a, b)
矩阵内积 dot
\[c=\sum_{i} \sum_{j} A_{i j} B_{i j}=A_{i j} B_{i j} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->', [a, b])
tensor(145.)
哈达玛积
\[C_{i j}=A_{i j} B_{i j} \]a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
torch.einsum('ij,ij->ij', [a, b])
tensor([[ 0., 7., 16.],
[ 27., 40., 55.]])
外积 outer
\[C_{i j}=a_{i} b_{j} \]a = torch.arange(3)
b = torch.arange(3,7)
torch.einsum('i,j->ij', [a, b])
tensor([[ 0., 0., 0., 0.],
[ 3., 4., 5., 6.],
[ 6., 8., 10., 12.]])
einsum规则总结:
- 表达式由输入和输出两部分组成。例子,
ij->ji
- 输出可以省略,箭头也可以省略。输入中仅出现一次的字符将按照字母序构成输出。例子,
ba
完整的表达式是ba->ab
- 输入中多次出现的字符,将被用作求和。例子,
kj,ji
完整的表达式是kj,ji->ik
,矩阵乘法再相乘。 - 输出可以指定,但是输出中的每个字符必须在输入中出现至少一次,输出的每个字符在输出中只能出现最多一次。例子,
ab->aa
是非法的,ab->c
是非法的,ab->a
是合法的。 - 省略符
...
是用来跳过部分维度。例子,...ij,...jk
表示 batch 矩阵乘法。 - 在输出没有指定的情况下,省略符优先级高于普通字符。例子,
b...a
完整的表达式是b...a->...ab
,可以将一个形状为(a,b,c)
的矩阵变为形状为(b,c,a)
的矩阵。 - 允许多个矩阵输入,表达式中使用逗号分开不同矩阵输入的下标。例子,
i,i,i
表示将三个一维向量按位相乘,并相加。 - 除了箭头,其他任何地方都可以加空格。例子,
i j , j k -> ik
是合法的,ij,jk - > ik
是非法的。 - 输入的表达式,维度需要和输入的矩阵对上,不能多也不能少。比如一个 shape 为
(4,3,3)
的矩阵,表达式ab->a
是非法的,abc->
是合法的。
实际使用
实现multi headed attention
https://nn.labml.ai/transformers/mha.html
如何优雅地实现多头自注意力
计算注意力score:
\[Q K^{\top} or S_{i j b h}=\sum_{d} Q_{i b h d} K_{j b h d} \]# q k v均为 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key) # 理解为ibhd,jbhd->ibhj->ijbh
计算attention输出:
\[\underset{\text { seq }}{\operatorname{softmax}}\left(\frac{Q K^{\top}}{\sqrt{d_{k}}}\right) V \]# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k]
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]
标签:torch,函数,求和,einsum,矩阵,张量,arange,ij 来源: https://www.cnblogs.com/qftie/p/16245124.html