调整ARIMA进行预测:Python中的简单方法
作者:互联网
这篇文章将介绍一种直截了当的方法,可以估计与最先进的手动方法接近的参数。
我们将使用贝叶斯优化方法(Mango)在短短200次迭代中从108,000个可能选项中搜索最佳参数。
ARIMA时间序列预测模型非常适合具有趋势和季节性的序列。这是一个被广泛采用的经典模型,通常作为基准现代深度学习方法的基线。然而,估计其准确参数具有挑战性。研究人员和开发人员通常使用包括视觉绘图在内的试错方法。
ARIMA模型是什么?
ARIMA模型是“自动递归移动平均线”的缩写,是一类使用过去值来估计未来预测的模型。ARIMA模型由三个参数定义:p、d和q。
ARIMA模型在文献中研究了不同的变体。在这篇文章中,我们将使用statsmodels库中的实现。
整个笔记本显示一个简单的实现在这里可用。您可以为您的数据集修改此实现。根据需要创建单独的火车测试拆分。我简化了概述重要的调音步骤。
完整代码:使用芒果自动调音
import pandas as pd df = pd.read_csv('https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv') from statsmodels.tsa.arima.model import ARIMA from sklearn.metrics import mean_squared_error from mango import scheduler, Tuner def arima_objective_function(args_list): global data_values params_evaluated = [] results = [] for params in args_list: try: p,d,q = params['p'],params['d'], params['q'] trend = params['trend'] model = ARIMA(data_values, order=(p,d,q), trend = trend) predictions = model.fit() mse = mean_squared_error(data_values, predictions.fittedvalues) params_evaluated.append(params) results.append(mse) except: #print(f"Exception raised for {params}") #pass params_evaluated.append(params) results.append(1e5) #print(params_evaluated, mse) return params_evaluated, results param_space = dict(p= range(0, 30), d= range(0, 30), q =range(0, 30), trend = ['n', 'c', 't', 'ct'] ) conf_Dict = dict() conf_Dict['num_iteration'] = 200 data_values = list(df['#Passengers']) tuner = Tuner(param_space, arima_objective_function, conf_Dict) results = tuner.minimize() print('best parameters:', results['best_params']) print('best loss:', results['best_objective']) best parameters: {'d': 0, 'p': 17, 'q': 23, 'trend': 'ct'} best loss: 112.06886739549542
调音步骤
数据集:我们将使用一个简单的空中乘客数据集,记录航空公司乘客人数。
import pandas as pd df = pd.read_csv('https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv') df.head()
绘制系列图,以了解趋势和季节性
from matplotlib import pyplot as plt f = plt.figure() f.set_figwidth(15) f.set_figheight(6) plt.plot(df['#Passengers'], linewidth = 4, label = "original Series") plt.legend(fontsize=25) plt.xlabel('Months', fontsize = 25) plt.ylabel('Count', fontsize = 25) plt.show()
该数据集呈上升趋势,季节性为12个月。
传统上,一种方法可以使用领域知识从原始序列中去除趋势和季节性,然后使用剩余序列来预测未来。然而,我们将研究一种更直接的自动化方法。
如何自动调整参数?
我们将使用一个名为Mango的最先进的优化库来为我们的数据集找到最佳参数。让我们首先定义参数的范围。在这种优化方法中,我们定义了可能的参数范围。这个范围可能非常大,不需要精确。这些参数是从statsmodels库中定义的。
param_space = dict(p= range(0, 30), d= range(0, 30), q =range(0, 30), trend = ['n', 'c', 't', 'ct'] )
参数空间是使用python构造定义的:范围和列表。参数总可能组合的集合是30*30*30*4 = 108,000。因此,详尽的网格搜索非常耗时。我们将使用贝叶斯搜索优化器方法,在大约100次迭代内自动进行搜索。注意:根据您的数据集,范围的大小及其搜索空间可能会有所不同。定义一个大的搜索空间很好;让优化器为你做艰苦的工作。
训练ARIMA模型
要使用Mango,我们必须定义一个目标函数,该函数返回给定参数集的ARIMA模型错误。
from statsmodels.tsa.arima.model import ARIMA from sklearn.metrics import mean_squared_error from mango import scheduler, Tuner def arima_objective_function(args_list): global data_values params_evaluated = [] results = [] for params in args_list: try: p,d,q = params['p'],params['d'], params['q'] trend = params['trend'] model = ARIMA(data_values, order=(p,d,q), trend = trend) predictions = model.fit() mse = mean_squared_error(data_values, predictions.fittedvalues) params_evaluated.append(params) results.append(mse) except: #print(f"Exception raised for {params}") #pass params_evaluated.append(params) results.append(1e5) #print(params_evaluated, mse) return params_evaluated, results
我们从Mango库中获取参数,并返回参数及其结果。结果包括经过训练的ARIMA模型的错误。在这种情况下,错误是mean_squared_error。我们还包括try-catch语句,因为ARIMA模型可能不会对参数的每个组合/选择收敛。我们只返回模型工作的参数集。芒果内部优化使用这些参数,在很少的迭代中找到最佳模型(在本例中为100)。我们的目标是找到最小化错误函数的参数。
控制芒果迭代:配置参数。
来自芒果进口调度器,调谐器
from mango import scheduler, Tuner conf_Dict = dict() conf_Dict['num_iteration'] = 200 tuner = Tuner(param_space, arima_objective_function, conf_Dict)
可视化最佳模型预测
总的来说,我们看到总的可能参数组合非常大(108,000)。
def plot_arima(data_values, order = (1,1,1), trend = 'c'): print('final model:', order, trend) model = ARIMA(data_values, order=order, trend = trend) results = model.fit() error = mean_squared_error(data_values, results.fittedvalues) print('MSE error is:', error) from matplotlib import pyplot as plt f = plt.figure() f.set_figwidth(15) f.set_figheight(6) plt.plot(data_values, label = "original Series", linewidth = 4) plt.plot(results.fittedvalues, color='red', label = "Predictions", linestyle='dashed', linewidth = 3) plt.legend(fontsize = 25) plt.xlabel('Months', fontsize = 25) plt.ylabel('Count', fontsize = 25) plt.show() print(results['best_params']) order = (results['best_params']['p'], results['best_params']['d'], results['best_params']['q']) plot_arima(data_values, order=order, trend = results['best_params']['trend'])
标签:arima,statsmodels,python 来源: