编程语言
首页 > 编程语言> > 机器学习算法(一元线性回归)

机器学习算法(一元线性回归)

作者:互联网

import matplotlib.pyplot  as plt
import numpy as np
import pandas as pd
from sklearn import datasets, linear_model


# 读取所需数据
def get_data(file_name):
    data = pd.read_csv(file_name)  # 获取Dataframe对象
    X_parameter = []
    Y_parameter = []
    for single_square_feet, single_price_value in zip(data['square_feet'], data['price']):
        X_parameter.append([float(single_square_feet)])     #加中括号变成二维数组,x,以及用这个x预测的y值
        Y_parameter.append(float(single_price_value))
    return X_parameter, Y_parameter


# 拟合线性模型
def linear_model_main(X_parameters, Y_parameters, predict_value):
    regr = linear_model.LogisticRegression()  # 创建线性回归对象
    regr.fit(X_parameters, Y_parameters)  # 拟合
    predict_outcome = regr.predict(predict_value)  # 调用线性回归对象的预测方法
    predictions = {}  # 定义一个空字典,存储拟合得到的斜率和截距,预测值
    predictions['intercept'] = regr.intercept_
    predictions['coefficient'] = regr.coef_
    predictions['predicted_value'] = predict_outcome
    return predictions


# 显示拟合线性模型的结果
def show_linear_line(X_parameters, Y_parameters):
    regr = linear_model.LinearRegression()
    regr.fit(X_parameters, Y_parameters)
    plt.scatter(X_parameters, Y_parameters, color='blue')
    plt.plot(X_parameters, regr.predict(X_parameters), color='red', linewidth=4)
    # plt.xticks(())  # 参数是xtick位置的列表。和一个可选参数。如果将一个空列表作为参数传递,则它将删除所有xticks
    # plt.yticks(())
    plt.show()


X, Y = get_data('input_data.csv')  # 传入所需数据
predictvalue = 700
result = linear_model_main(X, Y, predictvalue)  # 结果字典
print("Intercept value:", result['intercept'])
print("Coefficient:", result['coefficient'])
print("Predicted value:", result['predicted_value'])
show_linear_line(X, Y)

 

 

标签:一元,plt,linear,parameters,predict,regr,value,算法,线性
来源: https://blog.csdn.net/qq_42433311/article/details/121293958