编程语言
首页 > 编程语言> > Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer

Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer

作者:互联网

Alink漫谈(十八) :源码解析 之 多列字符串编码MultiStringIndexer

目录

0x00 摘要

Alink 是阿里巴巴基于实时计算引擎 Flink 研发的新一代机器学习算法平台,是业界首个同时支持批式算法、流式算法的机器学习平台。本文将带领大家来分析Alink中 MultiStringIndexer 的实现。

因为Alink的公开资料太少,所以以下均为自行揣测,肯定会有疏漏错误,希望大家指出,我会随时更新。

本文缘由是想分析GBDT,发现GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。

0x01 概念

Alink的官方介绍是:MultiStringIndexer训练组件的作用是训练一个模型用于将多列字符串映射为整数。

具体来说,StringIndexer(字符串-索引变换)将标签的"字符串列"编码为"标签索引的列"。

以这些输入为例:

("football", "can"),
("football", "hhh"),
("football", "zzz"),
("basketball", "zzz"),
("basketball", "can"),
("tennis", "can")

对于第一列,MultiStringIndexer 对数据集的label进行重新编号。按label出现的频次,转换成0 ~ numOfLabels - 1(分类个数)。如果是按照从高到低排序,则频次最高的转换为0,以此类推,比如:

在应用StringIndexer对labels进行重新编号后,带着这些编号后的label对数据进行了训练,并接着对其他数据进行了预测,得到预测结果,预测结果的label也是重新编号过的,因此需要转换回来。

0x02 示例代码

示例代码如下,本示例代码中,是按照升序排列,即football总数为3,则其idx为3,tennis个数为1,其idx为0:

public class MultiStringIndexerExample {
    static AlgoOperator getData(boolean isBatch) {
        Row[] array = new Row[] {
                Row.of("football", "can"),
                Row.of("football", "hhh"),
                Row.of("football", "zzz"),
                Row.of("basketball", "zzz"),
                Row.of("basketball", "can"),
                Row.of("tennis", "can")
        };

        if (isBatch) {
            return new MemSourceBatchOp(
                    Arrays.asList(array), new String[] {"a", "b"});
        } else {
            return new MemSourceStreamOp(
                    Arrays.asList(array), new String[] {"a", "b"});
        }
    }

    public static void main(String[] args) throws Exception {
        BatchOperator data = (BatchOperator)getData(true);
        MultiStringIndexer stringindexer = new MultiStringIndexer()
                .setSelectedCols("a", "b")
                .setOutputCols("a_indexed", "b_indexed")
                .setStringOrderType("frequency_asc");
        stringindexer.fit(data).transform(data).print();
    }
}

输出如下:

a|b|a_indexed|b_indexed
-|-|---------|---------
football|can|2|2
football|hhh|2|0
football|zzz|2|1
basketball|zzz|1|1
basketball|can|1|2
tennis|can|0|2

转换成表格看的更清楚。

a b a_indexed b_indexed
football can 2 2
football hhh 2 0
football zzz 2 1
basketball zzz 1 1
basketball can 1 2
tennis can 0 2

0x03 总体逻辑

我们先给出一个流程图

老套路,我们从 MultiStringIndexerTrainBatchOp.linkFrom开始挖掘。

@Override
public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);

    // 示例中有 .setSelectedCols("a", "b"),这里是取出具体列名字
    final String[] selectedColNames = getSelectedCols();
    // 获取列的类型
    final String[] selectedColSqlType = new String[selectedColNames.length];
    for (int i = 0; i < selectedColNames.length; i++) {
        selectedColSqlType[i] = FlinkTypeConverter.getTypeString(
            TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
    }

// runtime打印数据
selectedColNames = {String[2]@2536} 
 0 = "a"
 1 = "b"
selectedColSqlType = {String[2]@2537} 
 0 = "VARCHAR"
 1 = "VARCHAR"
  
    // 获取选取列对应的数据
    DataSet<Row> inputRows = in.select(selectedColNames).getDataSet();
    // 
    DataSet<Tuple3<Integer, String, Long>> indexedToken =
        StringIndexerUtil.indexTokens(inputRows, getStringOrderType(), 0L, true);

    DataSet<Row> values = indexedToken
        .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
            @Override
            public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
                throws Exception {
                Params meta = null;
                if (getRuntimeContext().getIndexOfThisSubtask() == 0) {           
                    // 第一个task会做这个计算,就是把列名,列类型作为元数据传送
                    meta = new Params().set(HasSelectedCols.SELECTED_COLS, selectedColNames)
                        .set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColSqlType);
                }
 
// runtime打印数据              
meta = {Params@9311} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
 params = {HashMap@9316}  size = 2              
              
                new MultiStringIndexerModelDataConverter().save(Tuple2.of(meta, values), out);
            }
        })
        .name("build_model");

    this.setOutput(values, new MultiStringIndexerModelDataConverter().getModelSchema());
    return this;
}

训练过程总体逻辑总结如下:

下面具体剖析后两个阶段。

0x04 Add Index to Token

这部分就是给各个列的不同字串赋予连续的indices。每列的 indices 彼此不相关。

具体是由StringIndexerUtil.indexTokens 做到的。

public static DataSet<Tuple3<Integer, String, Long>> indexTokens(
    DataSet<Row> data, HasStringOrderTypeDefaultAsRandom.StringOrderType orderType,
    final long startIndex, final boolean ignoreNull) {
    		case FREQUENCY_ASC:
                return indexSortedByFreq(data, startIndex, ignoreNull, true);
}

4.1 合并计算单词个数

indexSortedByFreq会调用countTokens来计算单词个数,所以我们先看countTokens。

countTokens的作用是按照 "列idx","word" 来合并计算单词个数,比如第一列中,football这个单词的个数是3,则返回三元组是 <0,football,3>,其中列的idx从0开始计算。

具体逻辑如下:

4.1.1 打散输入数据

其中 flattenTokens 的作用是把输入数据 Row 给打散,返回 A DataSet of tuples of column index and token.。

比如对于 Row.of("football", "can") 这个输入,flattenTokens 使用 out.collect(Tuple2.of(i, String.valueOf(o))); 输出两个Tuple2。

value = {Row@9212} "football,can"
 fields = {Object[2]@9215} 
  0 = "football"
  1 = "can"
  
输出 <0, "football"> 和 <1, "can">

4.1.2 分组计算个数

这是通过flattenTokens的结果进行 map,groupBy,reduce的一系列操作完成的。

具体代码如下:

public static DataSet<Tuple3<Integer, String, Long>> countTokens(DataSet<Row> data, final boolean ignoreNull) {
    return flattenTokens(data, ignoreNull) // 把输入数据 Row 给打散
        .map(new MapFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
            @Override
            public Tuple3<Integer, String, Long> map(Tuple2<Integer, String> value) throws Exception {
                return Tuple3.of(value.f0, value.f1, 1L); // 输出<column idx, word, 1L>,比如 <0, "football", 1L> 
            }
        })
        .groupBy(0, 1) // 按照 "列idx","word" 来分组
        .reduce(new ReduceFunction<Tuple3<Integer, String, Long>>() {
            @Override
            public Tuple3<Integer, String, Long> reduce(Tuple3<Integer, String, Long> value1, Tuple3<Integer, String, Long> value2) throws Exception {
                value1.f2 += value2.f2;
                return value1; // 按照 "列idx","word" 来合并计算单词个数
            }
        })
        .name("count_tokens");
}

// reduce之后发出
value1 = {Tuple3@9284} "(0,football,3)"
 f0 = {Integer@9226} 0
 f1 = "football"
 f2 = {Long@9295} 3

4.2 合并计算单词个数

前面 countTokens的 返回三元组是 <列idx","word" ,词频>,其中列的idx从0开始计算。

indexSortedByFreq会对countTokens返回的结果<"列idx","word",词频>处理;

具体代码如下:

public static DataSet<Tuple3<Integer, String, Long>> indexSortedByFreq(
    DataSet<Row> data, final long startIndex, final boolean ignoreNull, final boolean isAscending) {
    return countTokens(data, ignoreNull)
        .groupBy(0) //按照 列idx 做分组
        .sortGroup(2, isAscending ? Order.ASCENDING : Order.DESCENDING) //按照单词个数排序
        .reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
            @Override
            public void reduce(Iterable<Tuple3<Integer, String, Long>> values,
                               Collector<Tuple3<Integer, String, Long>> out) {
                long id = startIndex;
                for (Tuple3<Integer, String, Long> value : values) {
                    out.collect(Tuple3.of(value.f0, value.f1, id++)); // 归并
                }
            }
        });
}

0x05 输出模型

这部分分为两部分:

public class MultiStringIndexerModelDataConverter implements
    ModelDataConverter<Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>>, MultiStringIndexerModelData> {
    @Override
    public void save(Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>> modelData, Collector<Row> collector) {
        if (modelData.f0 != null) {
            collector.collect(Row.of(-1L, modelData.f0.toJson(), null));
        }
        modelData.f1.forEach(tuple -> {
            collector.collect(Row.of(tuple.f0.longValue(), tuple.f1, tuple.f2));
        });
    }  
}

tuple = {Tuple3@9405} "(0,tennis,0)"
 f0 = {Integer@9406} 0
 f1 = "tennis"
 f2 = {Long@9408} 0

0x06 预测

预测功能是在 ModelMapperAdapter 完成的。

public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
    private final ModelMapper mapper;
    private final ModelSource modelSource;

    @Override
    public void open(Configuration parameters) throws Exception {
        List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
        this.mapper.loadModel(modelRows); //加载模型
    }

    @Override
    public Row map(Row row) throws Exception {
        return this.mapper.map(row); //预测
    }
}

6.1 加载模型

MultiStringIndexerModelDataConverter中我们会进行模型加载。

public MultiStringIndexerModelData load(List<Row> rows) {
    MultiStringIndexerModelData modelData = new MultiStringIndexerModelData();
    modelData.tokenAndIndex = new ArrayList<>();
    modelData.tokenNumber = new HashMap<>();
    for (Row row : rows) {
        long colIndex = (Long) row.getField(0);
        if (colIndex < 0L) { // 元数据
            modelData.meta = Params.fromJson((String) row.getField(1));
        } else { // 具体模型信息
            int columnIndex = ((Long) row.getField(0)).intValue();
            Long tokenIndex = Long.valueOf(String.valueOf(row.getField(2)));
            modelData.tokenAndIndex.add(Tuple3.of(columnIndex, (String) row.getField(1), tokenIndex));
            modelData.tokenNumber.merge(columnIndex, 1L, Long::sum); // 合并列数据个数
        }
    }

    // To ensure that every columns has token number.
    int numFields = 0;
    if (modelData.meta != null) {
        numFields = modelData.meta.get(HasSelectedCols.SELECTED_COLS).length;
    }
    for (int i = 0; i < numFields; i++) {
        modelData.tokenNumber.merge(i, 0L, Long::sum);
    }
    return modelData;
}

最后模型内容如下,其中 tokenNumber 表示每列的数据有几个,tokenAndIndex表示具体信息,比如(0,tennis,0),(0,basketball,1),(0,football,2) 就表示他们都是第一列的,basketball转换后的数据是 1:

modelData = {MultiStringIndexerModelData@9348} 
 meta = {Params@9440} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
 tokenAndIndex = {ArrayList@9360}  size = 6
  0 = {Tuple3@9472} "(0,football,2)"
  1 = {Tuple3@9511} "(0,tennis,0)"
  2 = {Tuple3@9512} "(1,zzz,1)"
  3 = {Tuple3@9513} "(1,hhh,0)"
  4 = {Tuple3@9514} "(0,basketball,1)"
  5 = {Tuple3@9515} "(1,can,2)"
 tokenNumber = {HashMap@9385}  size = 2
  {Integer@9507} 0 -> {Long@9508} 3
  {Integer@9509} 1 -> {Long@9508} 3
numFields = 2

6.2 预测

预测是在 MultiStringIndexerModelMapper 完成的。

// 假设输入是:row = {Row@9309} "football,can"
// 选择的列是:selectedColNames = {String[2]@9314}  0 = "a" 1 = "b"
// 模型映射器是:
this = {MultiStringIndexerModelMapper@9309} 
 indexMapper = {HashMap@9318}  size = 2
  {Integer@9357} 0 -> {HashMap@9314}  size = 3
   key = {Integer@9357} 0
    value = 0
   value = {HashMap@9314}  size = 3
    "basketball" -> {Long@9386} 1
    "football" -> {Long@9332} 2
    "tennis" -> {Long@9384} 0
  {Integer@9352} 1 -> {HashMap@9358}  size = 3
   key = {Integer@9352} 1
    value = 1
   value = {HashMap@9358}  size = 3
    "can" -> {Long@9332} 2
    "hhh" -> {Long@9384} 0
    "zzz" -> {Long@9386} 1

则经历过下列代码,最后就可以进行预测

public Row map(Row row) throws Exception {
    Row result = new Row(selectedColNames.length);
    for (int i = 0; i < selectedColNames.length; i++) {
        Map<String, Long> mapper = indexMapper.get(i);
        int colIdxInData = selectedColIndicesInData[i];
        Object val = row.getField(colIdxInData);
        String key = val == null ? null : String.valueOf(val);
        Long index = mapper.get(key);
        if (index != null) {
            result.setField(i, index); // 我们主要执行在这里
        } else {
        }
    }
  
// 最后预测结果是:
row = {Row@9308} "football,can"
result = {Row@9313} "2,2"
    
    return outputColsHelper.getResultRow(row, result);
}

0xFF 参考

Spark之特征预处理

标签:Alink,football,Tuple3,public,源码,MultiStringIndexer,new,Long,Row
来源: https://www.cnblogs.com/rossiXYZ/p/13429876.html