如何使用`__array__`添加`quantity.Quantity`-行为而不进行子类化?
作者:互联网
quantity.Quantity是numpy.ndarray的子类,用于处理物理量的算术和转换.如何在不进行子类化的情况下使用它的算术?以下方法使用__array __-方法 – 但只能工作80%,如您在最后看到的那样:
class Numeric(object):
def __init__(self, signal):
self.signal = signal
self._dimensionality = self.signal._dimensionality
self.dimensionality = self.signal.dimensionality
def __array__(self):
return self.signal
def __mul__(self, obj):
return self.signal.__mul__(obj)
def __rmul__(self, obj):
return self.signal.__rmul__(obj)
有了这个,我可以做到:
import quantities as pq
import numpy as np
num = Numeric(pq.Quantity([1,2,3], 'mV'))
q = pq.Quantity([2,3,4], 'mV')
n = np.array([3,4,5])
以下所有操作都返回正确的单元 – 除了最后一个单元缺少单元:
print num * num
# [1 4 9] mV**2
print num * q
# [ 2 6 12] mV**2
print num * n
# [ 3 8 15] mV
print q * num
# [ 2 6 12] mV**2
print n * num
# [ 3 8 15] <------- no unit!
有什么想法,为了保持正确的单位需要修理什么?
编辑:算术操作的返回类型/值应相当于:
> num.signal * num.signal
> num.signal * q
> num.signal * n
> q * num.signal
> n * num.signal#这不起作用
解决方法:
当Python看到x * y时,会发生什么:
>如果y是x的子类 – >; y .__ rmul __(x)被调用
除此以外:
> x .__ mul __(y)被调用
IF x .__ mul __(y)返回NotImplemented(与raise NotImplementedError不同
> y .__ rmul __(x)被调用
因此,有两种方法可以调用__rmul__ – 子类ndarray,或者ndarray不能与Numeric相乘.
你无法继承子类,显然ndarray很乐意使用Numeric,所以. . .
值得庆幸的是,numpy人为这种情况做好了准备 – 答案在于__array_wrap__方法:
def __array_wrap__(self, out_arr, context=None):
return type(self.signal)(out_arr, self.dimensionality)
我们使用原始信号类以及原始维度来为新的Numeric对象创建新信号.
整个位看起来像这样:
import quantities as pq
import numpy as np
class Numeric(object):
def __init__(self, signal):
self.signal = signal
self.dimensionality = self.signal.dimensionality
self._dimensionality = self.signal._dimensionality
def __array__(self):
return self.signal
def __array_wrap__(self, out_arr, context=None):
return type(self.signal)(out_arr, self.dimensionality)
def __mul__(self, obj):
return self.signal.__mul__(obj)
def __rmul__(self, obj):
return self.signal.__rmul__(obj)
num = Numeric(pq.Quantity([1,2,3], 'mV'))
q = pq.Quantity([2,3,4], 'mV')
n = np.array([3,4,5])
t = num * num
print type(t), t
t = num * q
print type(t), t
t = num * n
print type(t), t
t = q * num
print type(t), t
t = n * num
print type(t), t
运行时:
<class 'quantities.quantity.Quantity'> [1 4 9] mV**2
<class 'quantities.quantity.Quantity'> [ 2 6 12] mV**2
<class 'quantities.quantity.Quantity'> [ 3 8 15] mV
<class 'quantities.quantity.Quantity'> [ 2 6 12] mV**2
<class 'quantities.quantity.Quantity'> [ 3 8 15] mV
标签:python,numpy,units-of-measurement 来源: https://codeday.me/bug/20190609/1208234.html