其他分享
首页 > 其他分享> > 使用numba优化Jaccard距离性能

使用numba优化Jaccard距离性能

作者:互联网

我正在尝试使用Numba在python中实现最快的jaccard距离版本

@nb.jit()
def nbjaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))

def jaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))


%%timeit
nbjaccard("compare this string","compare a different string")

–12.4毫秒

%%timeit 
jaccard("compare this string","compare a different string")

–3.87毫秒

为什么numba版本需要更长的时间?有办法提高速度吗?

解决方法:

在我看来,允许纯对象模式的numba函数有点设计错误(或者,如果numba意识到整个函数都使用python对象,则不会发出警告),因为它们通常比纯python函数要慢一些.

Numba非常强大(与C扩展名或Cython相比,类型分派和您可以编写不带类型声明的python代码-确实很棒),但仅在它支持以下操作时:

> Supported Python features in numba
> Supported NumPy features in numba

这意味着“ nopython”模式不支持未列出的任何操作.如果numba必须降到“object mode”,请注意:

object mode

A Numba compilation mode that generates code that handles all values as Python objects and uses the Python C API to perform all operations on those objects. Code compiled in object mode will often run no faster than Python interpreted code, unless the Numba compiler can take advantage of loop-jitting.

这就是您的情况:您纯粹在对象模式下操作:

>>> nbjaccard.inspect_types()

[...]
# --- LINE 3 --- 
#   seq1 = arg(0, name=seq1)  :: pyobject
#   seq2 = arg(1, name=seq2)  :: pyobject
#   $0.1 = global(set: <class 'set'>)  :: pyobject
#   $0.3 = call $0.1(seq1)  :: pyobject
#   $0.4 = global(set: <class 'set'>)  :: pyobject
#   $0.6 = call $0.4(seq2)  :: pyobject
#   set1 = $0.3  :: pyobject
#   set2 = $0.6  :: pyobject

set1, set2 = set(seq1), set(seq2)

# --- LINE 4 --- 
#   $const0.7 = const(int, 1)  :: pyobject
#   $0.8 = global(len: <built-in function len>)  :: pyobject
#   $0.11 = set1 & set2  :: pyobject
#   $0.12 = call $0.8($0.11)  :: pyobject
#   $0.13 = global(float: <class 'float'>)  :: pyobject
#   $0.14 = global(len: <built-in function len>)  :: pyobject
#   $0.17 = set1 | set2  :: pyobject
#   $0.18 = call $0.14($0.17)  :: pyobject
#   $0.19 = call $0.13($0.18)  :: pyobject
#   $0.20 = $0.12 / $0.19  :: pyobject
#   $0.21 = $const0.7 - $0.20  :: pyobject
#   $0.22 = cast(value=$0.21)  :: pyobject
#   return $0.22

return 1 - len(set1 & set2) / float(len(set1 | set2))

如您所见,每个单独的操作都在Python对象上进行操作(如每行结尾的::: pyobject所示).这是因为numba不支持str和set.因此,绝对没有比这更快的了.除了您有一个想法如何使用numpy数组或均质列表(数字类型)解决此问题.

在我的计算机上,时间差要大得多(使用numba 0.32.0),但是各个时间却要快得多-微秒(10 **-6秒)而不是毫秒(10 **-3秒):

%timeit nbjaccard("compare this string","compare a different string")
10000 loops, best of 3: 84.4 µs per loop

%timeit jaccard("compare this string","compare a different string")
100000 loops, best of 3: 15.9 µs per loop

请注意,默认情况下,jit为lazy,因此应该在执行时间计时之前进行第一个调用-因为它包括编译代码的时间.

不过,您可以执行一种优化:如果您知道两个集合的交集,则可以计算并集的长度(如@Paul Hankin在他现在删除的答案中提到的那样):

len(union) = len(set1) + len(set2) - len(intersection)

这将导致以下(纯python)代码:

def jaccard2(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    num_intersection = len(set1 & set2)
    return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)

%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop

速度不快-但更好.

如果您使用,则还有一些改进的空间:

%load_ext cython

%%cython
def cyjaccard(seq1, seq2):
    cdef set set1 = set(seq1)
    cdef set set2 = set()

    cdef Py_ssize_t length_intersect = 0

    for char in seq2:
        if char not in set2:
            if char in set1:
                length_intersect += 1
            set2.add(char)

    return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))

%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop

这里的主要优点是,只需进行一次迭代就可以创建set2并计算相交中的元素数(根本不需要创建相交集)!

标签:performance,numba,python,cython
来源: https://codeday.me/bug/20191111/2019928.html