
Simple Decision Tree Implementation

Author: Internet (reposted)

1. Classification

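import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn import tree

# Load iris: 150 samples, 4 features, 3 classes
iris = load_iris()
X, y = iris.data, iris.target
# Fit a decision tree classifier with default settings
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
# Draw the fitted tree; filled=True colors each node by its majority class
plt.figure(figsize=(16, 12))
tree.plot_tree(clf, filled=True)
plt.show()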
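
By default, DecisionTreeClassifier scores candidate splits with the Gini impurity (criterion="gini"). The sketch below roughly shows what happens at a single node for a single feature. It is a hypothetical pure-Python illustration, not scikit-learn's implementation. It sorts the samples, tries the midpoint between each pair of consecutive distinct values as a threshold, and keeps the threshold whose two children have the lowest size-weighted Gini impurity. The helper names gini and best_split and the toy data are assumptions made up for this example.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions (0.0 for an empty node)."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_split(xs, ys):
    """Return (threshold, weighted_impurity) of the best split 'x <= t' on one feature.

    Hypothetical helper: threshold is None if no split beats the parent node.
    """
    pairs = sorted(zip(xs, ys))
    n = len(pairs)
    best = (None, gini(ys))  # baseline: impurity of the unsplit parent node
    for i in range(1, n):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # identical feature values cannot be separated
        t = (lo + hi) / 2  # candidate threshold at the midpoint
        left = [y for x, y in pairs if x <= t]
        right = [y for x, y in pairs if x > t]
        # Weight each child's impurity by its share of the samples
        score = (len(left) * gini(left) + len(right) * gini(right)) / n
        if score < best[1]:
            best = (t, score)
    return best


# Toy data (made up): one feature, two classes that separate cleanly
petal_length = [1.4, 1.3, 1.5, 4.7, 4.5, 4.9]
labels = [0, 0, 0, 1, 1, 1]
print(best_split(petal_length, labels))  # (3.0, 0.0)
```

A real tree repeats this search over every feature at every node and then recurses into the two children. That is why the plotted iris tree shows a feature, a threshold and a gini value in each internal node.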


2. Regression


# Import the necessary modules and libraries
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt

# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))  # add noise to every 5th sample (80 / 5 = 16 points)

# Fit two regression trees with different maximum depths
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)

# Predict on a dense grid over [0, 5) with step 0.01
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)

# Plot the results
plt.figure(figsize=(8,5))
plt.scatter(X, y, s=20, edgecolor="black",
            c="darkorange", label="data")  # original data points
plt.plot(X_test, y_1, color="cornflowerblue",
         label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)

plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
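
Both fitted curves are step functions. DecisionTreeRegressor splits to reduce squared error by default, and each leaf predicts the mean target of the training samples that land in it. A tree with max_depth=2 has at most 2^2 = 4 leaves, so it can output at most 4 distinct values. The deeper max_depth=5 tree has many more steps and starts to chase the noisy points. The sketch below fits a depth-1 "stump" in pure Python to show one such split. It is an illustrative assumption, not scikit-learn's code, and fit_stump, predict_stump and the toy data are hypothetical names and values.

```python
def sse(values):
    """Sum of squared deviations from the mean (0.0 for an empty list)."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)


def fit_stump(xs, ys):
    """Hypothetical depth-1 regression tree: returns (threshold, left_mean, right_mean)."""
    pairs = sorted(zip(xs, ys))
    best = None  # (total_sse, threshold, left_mean, right_mean)
    for i in range(1, len(pairs)):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # identical x values cannot be separated
        t = (lo + hi) / 2  # candidate threshold at the midpoint
        left = [y for x, y in pairs if x <= t]
        right = [y for x, y in pairs if x > t]
        cost = sse(left) + sse(right)  # squared error if each side predicts its mean
        if best is None or cost < best[0]:
            best = (cost, t, sum(left) / len(left), sum(right) / len(right))
    if best is None:
        # No valid split: a single leaf; threshold inf sends every x to the left mean
        mean = sum(ys) / len(ys)
        return (float("inf"), mean, mean)
    _, t, left_mean, right_mean = best
    return (t, left_mean, right_mean)


def predict_stump(stump, x):
    """Predict the mean of the side of the threshold that x falls on."""
    t, left_mean, right_mean = stump
    return left_mean if x <= t else right_mean


# Toy data (made up): a low plateau around 1.0 and a high plateau around 3.0
xs = [0.5, 1.0, 1.5, 3.0, 3.5, 4.0]
ys = [1.0, 1.5, 0.5, 3.0, 3.5, 2.5]
stump = fit_stump(xs, ys)
print(stump)  # (2.25, 1.0, 3.0)
```

Applying the same search recursively to each side, down to max_depth levels, produces the staircase curves in the plot.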


Source: https://blog.csdn.net/liaozhaocong/article/details/116426887