numba的有效平方欧几里德距离代码是否比numpy的有效平方欧氏距离代码慢?
作者:互联网
我修改了(Why this numba code is 6x slower than numpy code?)中最有效的代码,以便它可以处理x1为(n,m)
@nb.njit(fastmath=True,parallel=True)
def euclidean_distance_square_numba_v5(x1, x2):
res = np.empty((x1.shape[0], x2.shape[0]), dtype=x2.dtype)
for a_idx in nb.prange(x1.shape[0]):
for o_idx in range(x2.shape[0]):
val = 0.
for i_idx in range(x2.shape[1]):
tmp = x1[a_idx, i_idx] - x2[o_idx, i_idx]
val += tmp * tmp
res[a_idx, o_idx] = val
return res
但是,更有效的numpy版本仍然不是更有效:
def euclidean_distance_square_einsum(x1, x2):
return np.einsum('ij,ij->i', x1, x1)[:, np.newaxis] + np.einsum('ij,ij->i', x2, x2) - 2*np.dot(x1, x2.T)
输入为
a = np.zeros((1000000,512), dtype=np.float32)
b = np.zeros((100, 512), dtype=np.float32)
我得到的时间是numba代码为2.4723422527313232和numpy代码为0.8260958194732666.
解决方法:
是的,这是预期的.
您必须意识到的第一件事:点积是numpy版本的主力军,此处适用于较小的数组:
>>> def only_dot(x1, x2):
return - 2*np.dot(x1, x2.T)
>>> a = np.zeros((1000,512), dtype=np.float32)
>>> b = np.zeros((100, 512), dtype=np.float32)
>>> %timeit(euclidean_distance_square_einsum(a,b))
6.08 ms ± 312 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit(euclidean_only_dot(a,b))
5.25 ms ± 330 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
也就是说,其中有85%的时间都花在了上面.
当您查看numba代码时,看起来像是矩阵矩阵乘法的某种奇怪/不寻常/更复杂的版本-例如,可以看到相同的三个循环.
因此,基本上,您正在尝试击败最好的最佳算法之一.例如,这里是somebody trying to do it and failing.我的安装使用的是Intel的MKL版本,该版本必须比默认实现更复杂,默认实现为here.
有时候,在享受了整个乐趣之后,人们不得不承认自己的“重新发明的轮子”不如最先进的轮子……但是只有这样,人们才能真正欣赏它的性能.
标签:python,numpy,numba 来源: https://codeday.me/bug/20191013/1907557.html