编程语言
首页 > 编程语言> > 如何在2D数组上加速python curve_fit?

如何在2D数组上加速python curve_fit?

作者:互联网

我必须在大量数据(5 000 000)上使用curve_fit numpy函数.
所以基本上我已经创建了一个2D数组.第一维是要执行的配件数量,第二维是用于配件的点数.

t = np.array([0 1 2 3 4])

for d in np.ndindex(data.shape[0]):
  try:
    popt, pcov = curve_fit(func, t, np.squeeze(data[d,:]), p0=[1000,100])
  except RuntimeError:
    print("Error - curve_fit failed")

可以使用多处理来加快整个过程,但是它仍然很慢.
有没有办法以“向量化”方式使用curve_fit?

解决方法:

一种加快速度的方法是在curve_fit中添加一些先验知识.

如果知道参数的期望范围,并且不需要精确到第100个有效数字的精度,则可以大大加快计算速度.

在下面的示例中,您将适合param1和param2:

t = np.array([0 1 2 3 4])
def func(t, param1, param2):
  return param1*t + param2*np.exp(t)

for d in np.ndindex(data.shape[0]):
  try:
    popt, pcov = curve_fit(func, t, np.squeeze(data[d,:]), p0=[1000,100], 
                           bounds=([min_param1, min_param2],[max_param1, max_param2]),
                           ftol=0.5, xtol=0.5)
  except RuntimeError:
    print("Error - curve_fit failed")

注意额外的关键参数界限,ftol和xtol.您可以阅读有关它们的内容here.

标签:performance,curve-fitting,python,numpy
来源: https://codeday.me/bug/20191120/2041366.html