Spark和Java的分层抽样
作者:互联网
我想确保我正在对数据的分层样本进行培训.
似乎Spark 2.1和更早版本通过JavaPairRDD.sampleByKey(…)和JavaPairRDD.sampleByKeyExact(…)对此提供了支持,如here所述.
但是:我的数据存储在Dataset< Row>中,而不是JavaPairRDD中.第一列是标签,所有其他都是功能(从libsvm格式的文件导入).
获得我的数据集实例的分层样本的最简单方法是什么,最后有一个Dataset< Row>.再次?
在某种程度上,这个问题与Dealing with unbalanced datasets in Spark MLlib有关.
该possible duplicate没有提及Dataset< Row>.根本不是Java.它没有回答我的问题.
解决方法:
好的,因为the question here的答案实际上不是Java的,所以我用Java重写了它.
推理还是一样的想法.我们仍在使用sampleByKeyExact.暂时没有开箱即用的奇迹功能(火花2.1.0)
所以,你去:
package org.awesomespark.examples;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.*;
import scala.Tuple2;
import java.util.Map;
public class StratifiedDatasets {
public static void main(String[] args) {
SparkSession spark = SparkSession.builder()
.appName("Stratified Datasets")
.getOrCreate();
Dataset<Row> data = spark.read().format("libsvm").load("sample_libsvm_data.txt");
JavaPairRDD<Double, Row> rdd = data.toJavaRDD().keyBy(x -> x.getDouble(0));
Map<Double, Double> fractions = rdd.map(Tuple2::_1)
.distinct()
.mapToPair((PairFunction<Double, Double, Double>) (Double x) -> new Tuple2(x, 0.8))
.collectAsMap();
JavaRDD<Row> sampledRDD = rdd.sampleByKeyExact(false, fractions, 2L).values();
Dataset<Row> sampledData = spark.createDataFrame(sampledRDD, data.schema());
sampledData.show();
sampledData.printSchema();
}
}
现在打包并提交:
$sbt package
[...]
// [success] Total time: 2 s, completed Jan 16, 2017 1:45:51 PM
$spark-submit --class org.awesomespark.examples.StratifiedDatasets target/scala-2.10/java-stratified-dataset_2.10-1.0.jar
[...]
// +-----+--------------------+
// |label| features|
// +-----+--------------------+
// | 0.0|(692,[127,128,129...|
// | 1.0|(692,[158,159,160...|
// | 1.0|(692,[124,125,126...|
// | 1.0|(692,[152,153,154...|
// | 1.0|(692,[151,152,153...|
// | 0.0|(692,[129,130,131...|
// | 1.0|(692,[99,100,101,...|
// | 0.0|(692,[154,155,156...|
// | 0.0|(692,[127,128,129...|
// | 1.0|(692,[154,155,156...|
// | 0.0|(692,[151,152,153...|
// | 1.0|(692,[129,130,131...|
// | 0.0|(692,[154,155,156...|
// | 1.0|(692,[150,151,152...|
// | 0.0|(692,[124,125,126...|
// | 0.0|(692,[152,153,154...|
// | 1.0|(692,[97,98,99,12...|
// | 1.0|(692,[124,125,126...|
// | 1.0|(692,[156,157,158...|
// | 1.0|(692,[127,128,129...|
// +-----+--------------------+
// only showing top 20 rows
// root
// |-- label: double (nullable = true)
// |-- features: vector (nullable = true)
对于python用户,您还可以查看我的答案Stratified sampling with pyspark.
标签:apache-spark,machine-learning,apache-spark-mllib,java 来源: https://codeday.me/bug/20191026/1935802.html