其他分享
首页 > 其他分享> > 如何使用`__array__`添加`quantity.Quantity`-行为而不进行子类化?

如何使用`__array__`添加`quantity.Quantity`-行为而不进行子类化?

作者:互联网

quantity.Quantity是numpy.ndarray的子​​类,用于处理物理量的算术和转换.如何在不进行子类化的情况下使用它的算术?以下方法使用__array __-方法 – 但只能工作80%,如您在最后看到的那样:

class Numeric(object):
  def __init__(self, signal):
    self.signal = signal
    self._dimensionality = self.signal._dimensionality
    self.dimensionality = self.signal.dimensionality
  def __array__(self):
    return self.signal
  def __mul__(self, obj):
    return self.signal.__mul__(obj)
  def __rmul__(self, obj):
    return self.signal.__rmul__(obj)

有了这个,我可以做到:

import quantities as pq
import numpy as np

num = Numeric(pq.Quantity([1,2,3], 'mV'))
q = pq.Quantity([2,3,4], 'mV')
n = np.array([3,4,5])

以下所有操作都返回正确的单元 – 除了最后一个单元缺少单元:

print num * num
# [1 4 9] mV**2
print num * q
# [ 2  6 12] mV**2
print num * n
# [ 3  8 15] mV
print q * num
# [ 2  6 12] mV**2
print n * num
# [ 3  8 15] <------- no unit!

有什么想法,为了保持正确的单位需要修理什么?

编辑:算术操作的返回类型/值应相当于:

> num.signal * num.signal
> num.signal * q
> num.signal * n
> q * num.signal
> n * num.signal#这不起作用

解决方法:

当Python看到x * y时,会发生什么:

>如果y是x的子类 – >; y .__ rmul __(x)被调用

除此以外:

> x .__ mul __(y)被调用

IF x .__ mul __(y)返回NotImplemented(与raise NotImplementedError不同

> y .__ rmul __(x)被调用

因此,有两种方法可以调用__rmul__ – 子类ndarray,或者ndarray不能与Numeric相乘.

你无法继承子类,显然ndarray很乐意使用Numeric,所以. . .

值得庆幸的是,numpy人为这种情况做好了准备 – 答案在于__array_wrap__方法:

def __array_wrap__(self, out_arr, context=None):
    return type(self.signal)(out_arr, self.dimensionality)

我们使用原始信号类以及原始维度来为新的Numeric对象创建新信号.

整个位看起来像这样:

import quantities as pq
import numpy as np

class Numeric(object):
    def __init__(self, signal):
        self.signal = signal
        self.dimensionality = self.signal.dimensionality
        self._dimensionality = self.signal._dimensionality
    def __array__(self):
        return self.signal
    def __array_wrap__(self, out_arr, context=None):
        return type(self.signal)(out_arr, self.dimensionality)
    def __mul__(self, obj):
        return self.signal.__mul__(obj)
    def __rmul__(self, obj):
        return self.signal.__rmul__(obj)


num = Numeric(pq.Quantity([1,2,3], 'mV'))
q = pq.Quantity([2,3,4], 'mV')
n = np.array([3,4,5])

t = num * num
print type(t), t
t = num * q
print type(t), t
t = num * n
print type(t), t
t = q * num
print type(t), t
t = n * num
print type(t), t

运行时:

<class 'quantities.quantity.Quantity'> [1 4 9] mV**2
<class 'quantities.quantity.Quantity'> [ 2  6 12] mV**2
<class 'quantities.quantity.Quantity'> [ 3  8 15] mV
<class 'quantities.quantity.Quantity'> [ 2  6 12] mV**2
<class 'quantities.quantity.Quantity'> [ 3  8 15] mV

标签:python,numpy,units-of-measurement
来源: https://codeday.me/bug/20190609/1208234.html