Python: performance of nested loops in Numba
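For performance reasons, I've started using Numba in addition to NumPy. My Numba algorithm works, but I feel it should be faster. One part of it is slowing it down. Here is the code snippet: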
```python
import numba as nb
import numpy

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if (numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                            numpy.all(ws[x1][i:l] == ws[x3][i:l])):
                        y += 1
                    if (numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                            numpy.all(ws[x1][i:l] == ws[x3][i:l])):
                        y += 1
```
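I think the `if` statements are what slow it down. Is there a better way? (What I'm trying to implement here is related to a question I posted earlier: Count possibilites for single crossovers.) `ws` is a NumPy array of shape `(gn, l)` containing 0s and 1s.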
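For reference, this is what the inner `for i` loop tests for one triple of rows: it counts the cut points `i` (1 <= i < l) at which row `x1` equals the first `i` entries of row `x2` followed by the remaining entries of row `x3`, i.e. the single-crossover points. The plain-Python sketch below handles a single triple; the name `count_crossover_points` and its parameters are hypothetical, chosen only for illustration.

```python
def count_crossover_points(child, parent_a, parent_b):
    """Count cut points i (1 <= i < len(child)) such that
    child == parent_a[:i] + parent_b[i:].

    Hypothetical helper: mirrors the condition inside the `for i` loop
    for one (x1, x2, x3) triple, without NumPy or Numba.
    """
    l = len(child)
    count = 0
    for i in range(1, l):
        prefix_ok = list(child[:i]) == list(parent_a[:i])
        suffix_ok = list(child[i:]) == list(parent_b[i:])
        if prefix_ok and suffix_ok:
            count += 1
    return count


# [0, 1, 1, 1] == parent_a[:i] + parent_b[i:] for i = 1 and i = 2
print(count_crossover_points([0, 1, 1, 1], [0, 1, 0, 0], [1, 1, 1, 1]))  # → 2
```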
Answer:
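Since the logic is to check that all items are equal, you can take advantage of the fact that the comparison can be short-circuited (i.e., you can stop comparing) as soon as any item differs. I modified your original function slightly so that (1) you don't repeat the same comparison twice, and (2) the result is summed over all the nested loops, so the return values of the two versions can be compared: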
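The short-circuit idea can be isolated in a small helper that walks two rows element by element and returns at the first mismatch, instead of building a full boolean array with `ws[x1][0:i] == ws[x2][0:i]` and then reducing it with `np.all`. This is a sketch: the helper name `rows_equal` is hypothetical, and the `try`/`except` fallback is an assumption added only so the example also runs as plain Python where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption): run as ordinary Python if Numba is unavailable.
    def njit(func):
        return func


@njit
def rows_equal(ws, r1, r2, start, stop):
    """Return True if ws[r1, start:stop] equals ws[r2, start:stop] element-wise.

    Returns at the first mismatch, so no temporary boolean array is built.
    An empty range (start == stop) counts as equal, like np.all([]).
    """
    for j in range(start, stop):
        if ws[r1, j] != ws[r2, j]:
            return False
    return True


ws = np.array([[0, 1, 1, 0],
               [0, 1, 0, 0],
               [1, 1, 1, 0]])
print(rows_equal(ws, 0, 1, 0, 2))  # True: both rows start with [0, 1]
print(rows_equal(ws, 0, 2, 0, 4))  # False: returns at column 0
```

With such a helper, the body of the `for i` loop becomes `if rows_equal(ws, x1, x2, 0, i) and rows_equal(ws, x1, x3, i, l):`, and Python's `and` additionally skips the suffix scan whenever the prefix scan fails.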
```python
import numba as nb
import numpy as np

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                        ysum += 1
    return ysum

@nb.njit
def rfunc2(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    incr_y = True
                    # Prefix check: stop at the first mismatch with row x2
                    for j in range(i):
                        if ws[x1, j] != ws[x2, j]:
                            incr_y = False
                            break
                    # Suffix check runs only if the prefix matched
                    if incr_y:
                        for j in range(i, l):
                            if ws[x1, j] != ws[x3, j]:
                                incr_y = False
                                break
                    if incr_y:
                        y += 1
                        ysum += 1
    return ysum
```
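Going one step further than `rfunc2` (this is an alternative sketch, not part of the original answer): both checks depend only on the pairs `(x1, x2)` and `(x1, x3)`. If `p` is the length of the common prefix of rows `x1` and `x2`, and `s` is the length of the common suffix of rows `x1` and `x3`, then cut point `i` qualifies exactly when `i <= p` and `i >= l - s`. The inner `for i` loop therefore reduces to counting the integers in `[max(1, l - s), min(l - 1, p)]`. Precomputing `p` and `s` for all pairs costs O(gn² · l), after which the triple loop does O(1) work per triple. The name `rfunc3` and the optional-Numba fallback are assumptions made for illustration.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption): run as ordinary Python if Numba is unavailable.
    def njit(func):
        return func


@njit
def rfunc3(ws, a, l):
    """Same total as rfunc1/rfunc2, using precomputed prefix/suffix match lengths."""
    gn = a**l
    # pre[r1, r2]: number of leading positions where rows r1 and r2 agree
    pre = np.zeros((gn, gn), dtype=np.int64)
    # suf[r1, r2]: number of trailing positions where rows r1 and r2 agree
    suf = np.zeros((gn, gn), dtype=np.int64)
    for r1 in range(gn):
        for r2 in range(gn):
            p = 0
            while p < l and ws[r1, p] == ws[r2, p]:
                p += 1
            pre[r1, r2] = p
            s = 0
            while s < l and ws[r1, l - 1 - s] == ws[r2, l - 1 - s]:
                s += 1
            suf[r1, r2] = s

    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            hi = min(l - 1, pre[x1, x2])  # i may not exceed the common prefix
            for x3 in range(gn):
                lo = max(1, l - suf[x1, x3])  # suffix ws[x1][i:] must match x3
                if hi >= lo:
                    ysum += hi - lo + 1
    return ysum


ws = np.random.randint(0, 2, size=(2**5, 5))
print(rfunc3(ws, 2, 5))
```

On the same `ws`, `rfunc3(ws, a, l)` should return the same total as `rfunc1` and `rfunc2`; the triple loop no longer touches the rows themselves, only the two precomputed tables.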
I don't know what your full function looks like, but hopefully this gets you started on the right track.
Now for some timings:
```python
import numpy as np

l = 7
a = 2
gn = a**l
ws = np.random.randint(0, 2, size=(gn, l))
```

```
In [23]: %timeit rfunc1(ws, a, l)
1 loop, best of 3: 2.11 s per loop

%timeit rfunc2(ws, a, l)
1 loop, best of 3: 39.9 ms per loop

In [27]: rfunc1(ws, a, l)
Out[27]: 131919

In [30]: rfunc2(ws, a, l)
Out[30]: 131919
```
Both versions return the same total, and the rewrite is roughly 50x faster (2.11 s vs. 39.9 ms).
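When timing Numba functions by hand rather than with `%timeit`, it helps to call the function once before measuring, so that JIT compilation is not counted. A minimal sketch of such a harness; `best_time` and `slow_sum` are hypothetical names used only for this illustration.

```python
import time


def best_time(func, *args, repeat=3):
    """Call func(*args) once as a warm-up (for a Numba function this triggers
    JIT compilation), then return (result, best wall-clock seconds over
    `repeat` timed runs)."""
    result = func(*args)  # warm-up call, excluded from timing
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        result = func(*args)
        best = min(best, time.perf_counter() - start)
    return result, best


def slow_sum(n):
    # Stand-in workload; replace with e.g. rfunc2 and (ws, a, l).
    return sum(range(n))


result, seconds = best_time(slow_sum, 1_000_000)
print(result, f"{seconds * 1e3:.2f} ms")
```

Because the harness returns the result along with the time, comparing `best_time(rfunc1, ws, a, l)[0]` with `best_time(rfunc2, ws, a, l)[0]` also checks that the two versions agree, as in the session above.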
Tags: python, numpy, numba
Source: https://codeday.me/bug/20191012/1897077.html