Python机器学习——预测分析核心算法(学习笔记五)
作者:互联网
Python机器学习——预测分析核心算法
第 6 章 集成方法
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.tree import DecisionTreeRegressor
from sklearn.externals.six import StringIO
import math
import matplotlib.pyplot as plt
import os
# Fit a shallow regression tree to the UCI red-wine quality data and dump the
# fitted tree in Graphviz DOT format.
target_url = "http://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv"
data = pd.read_csv(target_url, sep=';')  # the file is semicolon-delimited

# Every column except the last is used as a feature; the last column is the
# regression target.
xList = data.iloc[:, :-1]
yList = data.iloc[:, -1]

# max_depth=3 keeps the tree small enough to read once rendered.
dtr = DecisionTreeRegressor(max_depth=3)
dtr.fit(xList, yList)

# NOTE(review): machine-specific output directory; os.chdir raises if it does
# not exist on the current machine.
os.chdir(r'C:\Users\YY\Desktop\study\python\python_mechina_learning\06')
with open('temp1.doc', 'w') as f:
    # export_graphviz writes the DOT text into the open handle; its return
    # value is not needed, so the handle name is no longer rebound.
    tree.export_graphviz(dtr, out_file=f)
简单回归问题的决策树训练 - simpleTree.py
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.tree import DecisionTreeRegressor
from sklearn.externals.six import StringIO
import matplotlib.pyplot as plt
# Build a one-dimensional toy regression problem: nPoints + 1 evenly spaced
# x values running from -0.5 to 0.5.
nPoints = 100
xPlot = [i / float(nPoints) - 0.5 for i in range(nPoints + 1)]

# Each sample is wrapped in its own one-element list before being passed to
# the tree's fit/predict calls.
x = [[value] for value in xPlot]

# Labels are the x value plus Gaussian noise (standard deviation 0.1); the
# fixed seed makes the noise reproducible from run to run.
np.random.seed(1)
y = [value + np.random.normal(scale=0.1) for value in xPlot]

# Show the raw noisy data before any model is fitted.
plt.plot(xPlot, y)
plt.axis('tight')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()
# Fit a depth-1 tree (a single split) to the toy data and dump it in Graphviz
# DOT format.
dtr = DecisionTreeRegressor(max_depth=1)
dtr.fit(x, y)

import os

# NOTE(review): machine-specific output directory; os.chdir raises if it does
# not exist on the current machine.
os.chdir(r'C:\Users\YY\Desktop\study\python\python_mechina_learning\06')
with open('temp2.doc', 'w') as f:
    # export_graphviz writes the DOT text into the open handle; its return
    # value is not needed, so the handle name is no longer rebound.
    tree.export_graphviz(dtr, out_file=f)
# Predict on the same inputs the depth-1 tree was trained on and overlay the
# prediction (dashed) on the noisy labels.
yHat = dtr.predict(x)

plt.figure()
plt.plot(xPlot, y, label='True y')
plt.plot(xPlot, yHat, linestyle='--', label='Tree Prediction')
plt.legend(bbox_to_anchor=(1, 0.2))
plt.axis('tight')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
# Brute-force scan over every candidate split index i: cut xPlot into a left
# part [0, i) and a right part [i, end), then record the summed squared
# deviation of each part from its own mean.
# NOTE(review): the error is computed on xPlot itself rather than on the noisy
# labels y that the trees are fitted to -- confirm this is intended.
sse = []
xMin = []
for i in range(1, len(xPlot)):
    # Slicing a list already yields a new list, so no extra copy is needed.
    lhList = xPlot[:i]
    rhList = xPlot[i:]

    lhAvg = np.mean(lhList)
    rhAvg = np.mean(rhList)

    lhSse = sum((s - lhAvg) * (s - lhAvg) for s in lhList)
    rhSse = sum((s - rhAvg) * (s - rhAvg) for s in rhList)

    sse.append(lhSse + rhSse)
    # Largest value on the left side of this candidate split.
    xMin.append(max(lhList))

plt.plot(range(1, len(xPlot)), sse)
plt.xlabel('Split Point Index')
plt.ylabel('Sum Square Error')
plt.show()

# Report the left-side boundary value of the split with the smallest total
# error (the first one, if several splits tie).
minSse = min(sse)
idxMin = sse.index(minSse)
print(xMin[idxMin])
# Repeat the fit with a much deeper tree (max_depth=6) and overlay its
# prediction (dashed) on the noisy labels.
dtr2 = DecisionTreeRegressor(max_depth=6)
dtr2.fit(x, y)
yHat = dtr2.predict(x)

plt.figure()
plt.plot(xPlot, y, label='True y')
plt.plot(xPlot, yHat, linestyle='--', label='predict y')
plt.legend(bbox_to_anchor=(1, 0.2))
plt.axis('tight')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
标签:plot,plt,Python,tree,学习,算法,dtr,xPlot,import 来源: https://blog.csdn.net/weixin_41221502/article/details/87917046