Python的torch.einsum方法
作者:互联网
涉及以下内容
简述
- 爱因斯坦求和约定(einsum)具有简洁优雅的规则,实现包括但不限于:向量内积,向量外积,矩阵乘法,转置和张量收缩等张量操作
例如
a = torch.rand(3, 4),b = torch.rand(4, 5),c = torch.einsum("ik,kj->ij", [a, b])
# einsum 的第一个参数 "ik,kj->ij" 描述张量的计算规则,且维度的字符只能是26个英文字母 'a' - 'z'
# einsum 的第一个参数可以不写包括箭头在内的右边部分,比如矩阵乘法 "ik,kj" 等价于 "ik,kj->ij" 输# 出保留输入只出现一次的索引,按字母表顺序排列
# einsum 的第一个参数支持 "..." 省略号,用于表示用户不关心的索引,
# einsum 的第二个参数 [a, b] 表示实际的输入张量列表,且真实维度需匹配规则
# 索引顺序可以任意,但 "ik,kj->ij" 如果写成 "ik,kj->ji" 后一将返回前一的转置
- 自由索引,箭头右边的索引,比如上述的 i 和 j
- 求和索引,只出现在箭头左边的索引,表示中间计算结果需要在此维度上求和后输出,比如上述的 k
实践
import torch
import numpy as np
# 1:矩阵乘法
a = torch.rand(2, 3)
b = torch.rand(3, 4)
ein_out = torch.einsum("ik,kj->ij", [a, b]).numpy() # ein_out = torch.einsum("ik,kj", [a, b]).numpy()
org_out = torch.mm(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 2:矩阵点乘
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)
ein_out = torch.einsum('ij,ij->ij', [a, b]).numpy()
org_out = torch.mul(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 3:张量后两维乘法
a = torch.randn(2, 3, 5)
b = torch.randn(2, 5, 3)
ein_out = torch.einsum('ijk,ikl->ijl', [a, b]).numpy()
org_out = torch.matmul(a, b).numpy() # org_out = torch.bmm(a, b).numpy() # batch矩阵乘法
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 4:矩阵转置
a = torch.arange(6).reshape(2, 3)
ein_out = torch.einsum('ij->ji', [a]).numpy()
org_out = torch.transpose(a, 0, 1).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 5:张量后两维转置
a = torch.randn(1, 2, 3, 4, 5)
ein_out = torch.einsum('...ij->...ji', [a]).numpy()
org_out = a.permute(0, 1, 2, 4, 3).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 6:矩阵求和
a = torch.arange(6).reshape(2, 3)
ein_out = torch.einsum('ij->', a).numpy()
org_out = torch.sum(a).numpy()
ein_out_i = torch.einsum('ij->i', a).numpy()
org_out_i = torch.sum(a, dim=1).numpy()
ein_out_j = torch.einsum('ij->j', a).numpy()
org_out_j = torch.sum(a, dim=0).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
print("input:\n", a)
print("ein_out_i: \n", ein_out_i)
print("org_out_i: \n", org_out_i)
print("is org_out_i == ein_out_i ?", np.allclose(ein_out, org_out))
print("input:\n", a)
print("ein_out_j: \n", ein_out_j)
print("org_out_J: \n", org_out_j)
print("is org_out_j == ein_out_j ?", np.allclose(ein_out, org_out))
# 7:矩阵提取对角线元素
a = torch.arange(9).reshape(3, 3)
ein_out = torch.einsum('ii->i', a).numpy()
org_out = torch.diagonal(a, 0).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 8:矩阵向量乘法
a = torch.rand(3, 4)
b = torch.arange(4.0)
ein_out = torch.einsum('ik,k->i', [a, b]).numpy() # ein_out_k = torch.einsum('ik,k', [a, b]).numpy()
org_out = torch.mv(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out_k: \n", ein_out)
print("org_out_k: \n", org_out)
print("is org_out_k == ein_out_k ?", np.allclose(ein_out, org_out))
# 9:向量内积
a = torch.arange(3)
b = torch.arange(3, 6)
ein_out = torch.einsum('i,i->', [a, b]).numpy() # ein_out = torch.einsum('i,i', [a, b]).numpy()
org_out = torch.dot(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 10:向量外积
a = torch.arange(3)
b = torch.arange(3, 5)
ein_out = torch.einsum('i,j->ij', [a, b]).numpy() # ein_out = torch.einsum('i,j', [a, b]).numpy()
org_out = torch.outer(a, b).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 11:张量收缩
a = torch.randn(1, 3, 5, 7)
b = torch.randn(11, 33, 3, 55, 5)
ein_out = torch.einsum('pqrs,tuqvr->pstuv', [a, b]).numpy()
org_out = torch.tensordot(a, b, dims=([1, 2], [2, 4])).numpy()
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
标签:einsum,Python,torch,print,org,out,numpy,ein 来源: https://blog.csdn.net/weixin_49371288/article/details/121314954