其他分享
首页 > 其他分享 > SparkMllib分类问题的模板代码

SparkMllib分类问题的模板代码

作者:互联网

import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/**
  * DESC: Template code for a classification problem (decision tree on the Iris CSV).
  *
  * Complete data processing and modeling process steps:
  *  - 1 - Prepare the SparkSession environment
  *  - 2 - Prepare the data
  *  - 3 - Read and parse the data
  *  - 4 - Inspect basic information about the data
  *  - 5 - Feature engineering
  *  - 6 - Prepare the algorithm
  *  - 7 - Train the model
  *  - 8 - Predict with the model
  *  - 9 - Evaluate the model
  *  - 10 - Save the model (left disabled below)
  *  - 11 - Predict on new data (left disabled below)
  */
object ClassficationModelTest {

  /**
    * Default location of the Iris CSV file.
    *
    * This is a public `var`, so callers may reassign it before calling `main`;
    * a path passed as `args(0)` takes precedence over it.
    */
  var datapath = "D:\\BigData\\Workspace\\SparkMachineLearningTest\\SparkMllib_BigData32\\src\\main\\resources\\iris.csv"

  /**
    * Runs the end-to-end workflow: load the CSV, index the label, assemble and
    * index the features, train a decision tree on an 80/20 split, and print the
    * accuracy on the training and test sets.
    *
    * @param args optional; when non-empty, `args(0)` is used as the CSV path
    *             instead of [[datapath]]
    */
  def main(args: Array[String]): Unit = {
    // Column names are defined once so a typo cannot silently point at a different column.
    val sepalLengthCol = "sepal_length"
    val sepalWidthCol = "sepal_width"
    val petalLengthCol = "petal_length"
    val petalWidthCol = "petal_width"
    val labelCol = "class"
    val indexedLabelCol = "classlabel"
    val featuresCol = "features"
    val indexedFeaturesCol = "vecindexFeatures"
    val predictionCol = "prces"

    val inputPath: String = args.headOption.getOrElse(datapath)

    // - 1 - Prepare the SparkSession environment
    val conf: SparkConf = new SparkConf().setAppName("ClassficationModelTest").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    try {
      // - 2 - Prepare the data / - 3 - Read and parse it.
      // The first row is the header; column types are inferred from the values.
      val irisDF: DataFrame = spark.read.format("csv")
        .option("header", true)
        .option("inferschema", true)
        .option("sep", ",")
        .load(inputPath)
      irisDF.show(10, false)
      // Sample output recorded by the author:
      // +------------+-----------+------------+-----------+-----------+
      // |sepal_length|sepal_width|petal_length|petal_width|class      |
      // +------------+-----------+------------+-----------+-----------+
      // |5.1         |3.5        |1.4         |0.2        |Iris-setosa|
      // |4.9         |3.0        |1.4         |0.2        |Iris-setosa|
      // |4.7         |3.2        |1.3         |0.2        |Iris-setosa|
      // |4.6         |3.1        |1.5         |0.2        |Iris-setosa|

      // - 4 - Inspect basic information about the data
      irisDF.printSchema()
      // root
      // |-- sepal_length: double (nullable = true)
      // |-- sepal_width: double (nullable = true)
      // |-- petal_length: double (nullable = true)
      // |-- petal_width: double (nullable = true)
      // |-- class: string (nullable = true)

      // - 5 - Feature engineering
      // NOTE(review): both indexers below are fit on the full dataset before the
      // train/test split, so the test rows influence the fitted encodings. Consider
      // fitting them on the training split only (e.g. via a Pipeline).

      // 5-1 Encode the string label column as a numeric index.
      val stringIndexerModel: StringIndexerModel = new StringIndexer()
        .setInputCol(labelCol)
        .setOutputCol(indexedLabelCol)
        .fit(irisDF)
      val indexDF: DataFrame = stringIndexerModel.transform(irisDF)

      // 5-2 Combine the separate numeric columns into a single feature vector.
      val vectorAssembler: VectorAssembler = new VectorAssembler()
        .setInputCols(Array(sepalLengthCol, sepalWidthCol, petalLengthCol, petalWidthCol))
        .setOutputCol(featuresCol)
      val vecDF: DataFrame = vectorAssembler.transform(indexDF)

      // 5-3 VectorIndexer: features with at most `maxCategories` distinct values are
      // treated as categorical and re-indexed (which can help tree construction);
      // the others are left as continuous.
      val vectorIndexerModel: VectorIndexerModel = new VectorIndexer()
        .setInputCol(featuresCol)
        .setOutputCol(indexedFeaturesCol)
        .setMaxCategories(20)
        .fit(vecDF)
      val vecindexerDF: DataFrame = vectorIndexerModel.transform(vecDF)
      vecindexerDF.show(10, false)

      // - 6 - Prepare the algorithm
      val classifier: DecisionTreeClassifier = new DecisionTreeClassifier()
        .setLabelCol(indexedLabelCol)
        .setPredictionCol(predictionCol)
        .setFeaturesCol(indexedFeaturesCol)
        .setMaxDepth(5)
        .setImpurity("gini")
      // Fixed seed so the 80/20 split is reproducible between runs.
      val Array(trainingSet, testSet): Array[Dataset[Row]] = vecindexerDF.randomSplit(Array(0.8, 0.2), seed = 1234L)

      // - 7 - Train the model
      val model: DecisionTreeClassificationModel = classifier.fit(trainingSet)

      // - 8 - Predict with the model
      val trainPredictions: DataFrame = model.transform(trainingSet)
      val testPredictions: DataFrame = model.transform(testSet)
      trainPredictions.show(10, false)

      // - 9 - Evaluate the model
      // Other supported metric names include f1, weightedPrecision, weightedRecall.
      val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
        .setMetricName("accuracy")
        .setPredictionCol(predictionCol)
        .setLabelCol(indexedLabelCol)
      val accTest: Double = evaluator.evaluate(testPredictions)
      val accTrain: Double = evaluator.evaluate(trainPredictions)
      println(s"acc in trainset score is: $accTrain")
      println(s"acc in testset score is: $accTest")
      // Values recorded by the author:
      // acc in trainset score is: 0.9920634920634921
      // acc in testset score is: 0.9583333333333334

      // - 10 - Save the model (disabled). `overwrite()` avoids failing when the path exists:
      //   val modelPath = "D:\\BigData\\Workspace\\SparkMachineLearningTest\\SparkMllib_BigData32\\src\\main\\resources\\model1"
      //   model.write.overwrite().save(modelPath)
      // - 11 - Predict on new data (disabled). New data must go through the same
      //   feature steps (5-1 .. 5-3) before being passed to the loaded model:
      //   val loadedModel = DecisionTreeClassificationModel.load(modelPath)
      //   loadedModel.transform(preparedNewData).show(false)
    } finally {
      // Release the SparkContext even if one of the steps above throws.
      spark.stop()
    }
  }
}

标签:feature,val,petal,代码,SparkMllib,sepal,width,length,模板
来源: https://www.cnblogs.com/haojia/p/12396975.html