其他分享
首页 > 其他分享> > Spark UDAF 自定义函数

Spark UDAF 自定义函数

作者:互联网

需求

有udaf.json数据内容如下

{"name":"Michael","salary":3000}

{"name":"Andy","salary":4500}

{"name":"Justin","salary":3500}

{"name":"Berta","salary":4000}

 

求取 平均工资

 

●继承UserDefinedAggregateFunction方法重写说明

inputSchema:输入数据的类型

bufferSchema:产生中间结果的数据类型

dataType:最终返回的结果类型

deterministic:确保一致性,一般用true

initialize:指定初始值

update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)

merge:全局聚合(将每个分区的结果进行聚合)

evaluate:计算最终的结果

 

package SparkSql

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

/**
  * Created by 一个蔡狗 on 2020/4/13.
  *
  * UDAF  自定义函数
  *
  */
object UDAF_01 {

  //  继承   UserDefinedAggregateFunction   重写方法
  //  inputSchema:输入数据的类型
  //  bufferSchema:产生中间结果的数据类型
  //  dataType:最终返回的结果类型
  //  deterministic:确保一致性,一般用true
  //  initialize:指定初始值
  //  update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
  //  merge:全局聚合(将每个分区的结果进行聚合)
  //  evaluate:计算最终的结果

  class GetAvg extends UserDefinedAggregateFunction {

    //  inputSchema:输入数据的类型     StructType 表结构
    override def inputSchema: StructType = {
//       ::Nil   创建 list   ::Nil
      StructType(StructField("input",LongType)::Nil)


    }


    //  bufferSchema:产生中间结果的数据类型
    // sum   : 每次的 临时的 总和
    // total : 临时的总次数

    override def bufferSchema: StructType = {

      StructType(StructField("sum",LongType)::StructField("total",LongType)::Nil)

    }

    //  dataType:最终返回的结果类型
    override def dataType: DataType = {
      DoubleType
    }


    //  deterministic:确保一致性,一般用true
    override def deterministic: Boolean = {
      true

    }

    //初始化数据  1  设置   x 个变量  2 每个变量 进行 初始化数据
    override def initialize(buffer: MutableAggregationBuffer): Unit = {

      //  buffer(0) 作用 : 用于 记录临时的 数据  和
      buffer(0)=0L

      //  buffer(1) 作用 : 用于 记录临时的 数据  条数
      buffer(1)=0L

    }

    // RDD 中 有多个分区  update 计算一个 分区内的 数据
    //  update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {

      // buffer  记录临时的 数据的 和     input 输入的数据
      buffer(0)=buffer.getLong(0)+input.getLong(0)

      //临时输入的总数量(条数)

      buffer(1)=buffer.getLong(1)+1


    }

    //merge:全局聚合(将每个分区的结果进行聚合)
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

      //累加 第一个分区的 金额总和   与  第二个 分区的 金额总和
      buffer1(0) =buffer1.getLong(0)+buffer2.getLong(0)
      //离家 第一个分区的次数  和 第二个 分区的 次数
      buffer1(1)=buffer1.getLong(1)+buffer2.getLong(1)


    }

    //  evaluate:计算最终的结果
    override def evaluate(buffer: Row): Any = {
      //平均值
      buffer.getLong(0).toDouble/buffer.getLong(1).toDouble
    }


  }


  def main(args: Array[String]): Unit = {


    //    1创建 SparkSeesion
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("01").getOrCreate()
    val udafJson: DataFrame = spark.read.json("E:\\udaf.json")
    udafJson.show()
    //注册一个 udaf 函数
    spark.udf.register("GetAvg", new GetAvg())
    //注册成一张表
    udafJson.createOrReplaceTempView("UdafTabel")
    //查看薪水       注册一个  GetAvg()  方法 实现 平均工资
    spark.sql("select GetAvg(salary)  from UdafTabel ").show()
    spark.sql("select avg(salary)  from UdafTabel ").show()
    //关闭 spark
    spark.stop()


  }

   //熟练掌握 基于 Spark的 UDAF

}

 

标签:自定义,buffer,getLong,分区,update,UDAF,Spark,spark,def
来源: https://blog.csdn.net/bbvjx1314/article/details/105497426