如何在2D数组上加速python curve_fit?
作者:互联网
我必须在大量数据(5 000 000)上使用curve_fit numpy函数.
所以基本上我已经创建了一个2D数组.第一维是要执行的配件数量,第二维是用于配件的点数.
t = np.array([0 1 2 3 4])
for d in np.ndindex(data.shape[0]):
try:
popt, pcov = curve_fit(func, t, np.squeeze(data[d,:]), p0=[1000,100])
except RuntimeError:
print("Error - curve_fit failed")
可以使用多处理来加快整个过程,但是它仍然很慢.
有没有办法以“向量化”方式使用curve_fit?
解决方法:
一种加快速度的方法是在curve_fit中添加一些先验知识.
如果知道参数的期望范围,并且不需要精确到第100个有效数字的精度,则可以大大加快计算速度.
在下面的示例中,您将适合param1和param2:
t = np.array([0 1 2 3 4])
def func(t, param1, param2):
return param1*t + param2*np.exp(t)
for d in np.ndindex(data.shape[0]):
try:
popt, pcov = curve_fit(func, t, np.squeeze(data[d,:]), p0=[1000,100],
bounds=([min_param1, min_param2],[max_param1, max_param2]),
ftol=0.5, xtol=0.5)
except RuntimeError:
print("Error - curve_fit failed")
注意额外的关键参数界限,ftol和xtol.您可以阅读有关它们的内容here.
标签:performance,curve-fitting,python,numpy 来源: https://codeday.me/bug/20191120/2041366.html