编程语言
首页 > 编程语言> > python – Numpy:找到两个三维阵列之间的欧氏距离

python – Numpy:找到两个三维阵列之间的欧氏距离

作者:互联网

给定,两个维度的三维数组(2,2,2):

A = [[[ 0,  0],
    [92, 92]],

   [[ 0, 92],
    [ 0, 92]]]

B = [[[ 0,  0],
    [92,  0]],

   [[ 0, 92],
    [92, 92]]]

如何有效地找到A和B中每个向量的欧几里德距离?

我尝试了for-loops但这些都很慢,而且我正按照(>>>>>> 2,2)的顺序使用3-D数组.

最终我想要一个表格矩阵:

C = [[d1, d2],
     [d3, d4]]

编辑:

我尝试过以下循环,但最大的问题是丢失了我想保留的尺寸.但距离是正确的.

[numpy.sqrt((A[row, col][0] - B[row, col][0])**2 + (B[row, col][1] -A[row, col][1])**2) for row in range(2) for col in range(2)]

解决方法:

以NumPy向量化思维方式思考,该方法将执行元素分化,沿最后一个轴进行平方和求和,最后得到平方根.因此,直接的实施将是 –

np.sqrt(((A - B)**2).sum(-1))

我们可以一次性使用np.einsum对最后一个轴进行平方和求和,从而使其效率更高,如此 –

subs = A - B
out = np.sqrt(np.einsum('ijk,ijk->ij',subs,subs))

numexpr module的另一种选择 –

import numexpr as ne
np.sqrt(ne.evaluate('sum((A-B)**2,2)'))

因为,我们在最后一个轴上使用长度为2的长度,我们可以将它们切片并将其馈送到评估方法.请注意,在evaluate字符串中无法进行切片.因此,修改后的实施将是 –

a0 = A[...,0]
a1 = A[...,1]
b0 = B[...,0]
b1 = B[...,1]
out = ne.evaluate('sqrt((a0-b0)**2 + (a1-b1)**2)')

运行时测试

功能定义 –

def sqrt_sum_sq_based(A,B):
    return np.sqrt(((A - B)**2).sum(-1))

def einsum_based(A,B):
    subs = A - B
    return np.sqrt(np.einsum('ijk,ijk->ij',subs,subs))

def numexpr_based(A,B):
    return np.sqrt(ne.evaluate('sum((A-B)**2,2)'))

def numexpr_based_with_slicing(A,B):
    a0 = A[...,0]
    a1 = A[...,1]
    b0 = B[...,0]
    b1 = B[...,1]
    return ne.evaluate('sqrt((a0-b0)**2 + (a1-b1)**2)')

计时 –

In [288]: # Setup input arrays
     ...: dim = 2
     ...: N = 1000
     ...: A = np.random.rand(N,N,dim)
     ...: B = np.random.rand(N,N,dim)
     ...: 

In [289]: %timeit sqrt_sum_sq_based(A,B)
10 loops, best of 3: 40.9 ms per loop

In [290]: %timeit einsum_based(A,B)
10 loops, best of 3: 22.9 ms per loop

In [291]: %timeit numexpr_based(A,B)
10 loops, best of 3: 18.7 ms per loop

In [292]: %timeit numexpr_based_with_slicing(A,B)
100 loops, best of 3: 8.23 ms per loop

In [293]: %timeit np.linalg.norm(A-B, axis=-1) #@dnalow's soln
10 loops, best of 3: 45 ms per loop

标签:python,vectorization,matrix,numpy,euclidean-distance
来源: https://codeday.me/bug/20190608/1198724.html