数据库
首页 > 数据库> > SparkSQL的UDF函数和UDAF函数

SparkSQL的UDF函数和UDAF函数

作者:互联网


文章目录


UDF函数:用户自定义函数

	SparkConf conf = new SparkConf();
		conf.setMaster("local");
		conf.setAppName("udf");
		JavaSparkContext sc = new JavaSparkContext(conf);
		SQLContext sqlContext = new SQLContext(sc);
		JavaRDD<String> parallelize = sc.parallelize(Arrays.asList("zhangsan","lisi","wangwu"));
		JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() {
			private static final long serialVersionUID = 1L;

			@Override
			public Row call(String s) throws Exception {
				return RowFactory.create(s);
			}
		});

		/**
		 * 动态创建Schema方式加载DF
		 */
		List<StructField> fields = new ArrayList<StructField>();
		fields.add(DataTypes.createStructField("name", DataTypes.StringType,true));
		StructType schema = DataTypes.createStructType(fields);
		DataFrame df = sqlContext.createDataFrame(rowRDD,schema);
		df.registerTempTable("user");
		/**
		 * 根据UDF函数参数的个数来决定是实现哪一个UDF  UDF1,UDF2。。。。UDF1xxx
		 */
		sqlContext.udf().register("StrLen",new UDF2<String, Integer, Integer>() {
			private static final long serialVersionUID = 1L;
			@Override
			public Integer call(String t1, Integer t2) throws Exception {
				return t1.length() + t2;
			}
		} ,DataTypes.IntegerType );
		sqlContext.sql("select name ,StrLen(name,100) as length from user").show();
		sc.stop();

UDAF函数: 用户自定义聚合函数

   SparkConf conf = new SparkConf();
        conf.setMaster("local").setAppName("udaf");
        conf.set("spark.sql.shuffle.partitions", "1");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        JavaRDDparallelize = sc.parallelize(
                Arrays.asList("zhangsan", "lisi", "wangwu", "zhangsan", "zhangsan", "lisi","zhangsan", "lisi", "wangwu", "zhangsan", "zhangsan", "lisi"),2);
        JavaRDDrowRDD = parallelize.map(new Function() {
            private static final long serialVersionUID = 1L;
            @Override
            public Row call(String s) throws Exception {
                return RowFactory.create(s);
            }
        });
        Listfields = new ArrayList();
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);
        DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
        df.registerTempTable("user");

注册一个UDAF函数,实现统计相同值的个数,注意:这里可以自定义一个类继承UserDefinedAggregateFunction类

sqlContext.udf().register("StringCount", new UserDefinedAggregateFunction() {

            private static final long serialVersionUID = 1L;

            /**
             * 初始化一个内部的自己定义的值,在Aggregate之前每组数据的初始化结果
             */
            @Override
            public void initialize(MutableAggregationBuffer buffer) {
                buffer.update(0, 0);
                System.out.println("init ....." + buffer.get(0));

            }
            /**
             * 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
             * buffer.getInt(0)获取的是上一次聚合后的值
             * 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
             * 大聚和发生在reduce端.
             * 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
             */
            @Override
            public void update(MutableAggregationBuffer buffer, Row arg1) {
                buffer.update(0, buffer.getInt(0) + 1);
                System.out.println("update.....buffer" + buffer.toString() + " | row" + arg1.toString() );
            }

            /**
             * 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
             * 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
             * buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值
             * buffer2.getInt(0) : 这次计算传入进来的update的结果
             * 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
             */

            public void merge(MutableAggregationBuffer buffer1, Row arg1) {
                // 2 3  4  5  6  7
                // 0 + 2 = 2
                // 2 + 3 = 5
                // 5 + 4  = 9
                buffer1.update(0, buffer1.getInt(0) + arg1.getInt(0));
                System.out.println("merge.....buffer : " + buffer1.toString() + "| row" + arg1.toString() );
            }

            /**
             * 在进行聚合操作的时候所要处理的数据的结果的类型
             */
            @Override
            public StructType bufferSchema() {
                return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("bffer", DataTypes.IntegerType, true)));
            }

            /**
             * 最后返回一个和DataType的类型要一致的类型,返回UDAF最后的计算结果
             */
            @Override
            public Object evaluate(Row row) {
                return row.getInt(0);
            }

            /**
             * 指定UDAF函数计算后返回的结果类型
             */
            @Override
            public DataType dataType() {
                return DataTypes.IntegerType;
            }

            /**
             * 指定输入字段的字段及类型
             */
            @Override
            public StructType inputSchema() {
                return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("name", DataTypes.StringType, true)));
            }
            /**
             * 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果。
             */
            @Override
            public boolean deterministic() {
                return true;
            }
        });
        sqlContext.sql("select name ,StringCount(name) as number from user group by name").show();
        sc.stop();

               

标签:函数,buffer,UDAF,UDF,Override,new,DataTypes,public,sqlContext
来源: https://blog.51cto.com/u_13985831/2836519