其他分享
首页 > 其他分享> > 如何在Numpy中构建argsecondmax

如何在Numpy中构建argsecondmax

作者:互联网

在Numpy中,已经定义了argmax,但我需要argsecondmax,这基本上是第二个最大值.我怎么能这样做,我有点困惑?

解决方法:

寻找第N个最大指数

一个有效的可以使用np.argparition跳过排序和简单的paritition,切片时将给我们所需的索引.我们还要概括它以找到沿指定轴或全局的第N个最大的一个(类似于ndarray.argmax()),就像这样 –

def argNmax(a, N, axis=None):
    if axis is None:
        return np.argpartition(a.ravel(), -N)[-N]
    else:
        return np.take(np.argpartition(a, -N, axis=axis), -N, axis=axis)

样品运行 –

In [66]: a
Out[66]: 
array([[908, 770, 258, 534],
       [399, 376, 808, 750],
       [655, 654, 825, 355]])

In [67]: argNmax(a, N=2, axis=0)
Out[67]: array([2, 2, 1, 0])

In [68]: argNmax(a, N=2, axis=1)
Out[68]: array([1, 3, 0])

In [69]: argNmax(a, N=2) # global second largest index
Out[69]: 10

找到第N个最小的指数

为了找到沿着轴或全局的第N个最小的那个,我们将 –

def argNmin(a, N, axis=None):
    if axis is None:
        return np.argpartition(a.ravel(), N-1)[N-1]
    else:
        return np.take(np.argpartition(a, N-1, axis=axis), N-1, axis=axis)

样品运行 –

In [105]: a
Out[105]: 
array([[908, 770, 258, 534],
       [399, 376, 808, 750],
       [655, 654, 825, 355]])

In [106]: argNmin(a, N=2, axis=0)
Out[106]: array([2, 2, 1, 0])

In [107]: argNmin(a, N=2, axis=1)
Out[107]: array([3, 0, 1])

In [108]: argNmin(a, N=2)
Out[108]: 11

计时

为了透露使用argpartition与使用argsort进行实际排序的好处,如@pythonic833's post所示,这里是对全局argmax版本的快速运行时测试 –

In [70]: a = np.random.randint(0,99999,(1000,1000))

In [72]: %timeit np.argsort(a)[-2] # @pythonic833's soln
10 loops, best of 3: 40.6 ms per loop

In [73]: %timeit argNmax(a, N=2)
100 loops, best of 3: 2.12 ms per loop

标签:argmax,python,numpy,machine-learning
来源: https://codeday.me/bug/20190727/1549965.html