编程语言
首页 > 编程语言> > java-如何在Deeplearning4j中使用自定义数据模型?

java-如何在Deeplearning4j中使用自定义数据模型?

作者:互联网

基本问题是尝试使用自定义数据模型创建要在deeplearning4j网络中使用的DataSetIterator.

我要使用的数据模型是一个Java类,其中包含一堆双打,这些双打是根据特定股票的报价创建的,例如时间戳,开盘价,开盘价,收盘价,高价,低价,交易量,技术指标1,技术指标2,等等
我查询互联网资源example(也来自同一站点的其他几个指标),该资源提供了json字符串,我将其转换为数据模型以便于访问并存储在sqlite数据库中.

现在,我有了这些数据模型的清单,我想用它们来训练LSTM网络,每个模型都是一个功能.根据Deeplearning4j文档和几个示例,使用训练数据的方法是使用描述为here的ETL流程创建一个DataSetIterator,然后供网络使用.

在没有先将它们转换为其他格式(例如CSV或其他文件)的情况下,我看不到使用任何提供的RecordReader转换数据模型的干净方法.我想避免这种情况,因为它会占用大量资源.似乎会有更好的方法来处理此简单案例.有什么更好的方法我只是想念吗?

解决方法:

伊桑!

首先,Deeplearning4j使用ND4j作为后端,因此最终必须将您的数据转换为INDArray对象才能在模型中使用.如果trianing数据是两个双精度数组,即inputsArray和desiredOutputsArray,则可以执行以下操作:

INDArray inputs = Nd4j.create(inputsArray, new int[]{numSamples, inputDim});
INDArray desiredOutputs = Nd4j.create(desiredOutputsArray, new int[]{numSamples, outputDim});

然后,您可以直接使用这些向量训练模型:

for (int epoch = 0; epoch < nEpochs; epoch++)
    model.fit(inputs, desiredOutputs);

另外,您可以创建一个DataSet对象并将其用于训练:

DataSet ds = new DataSet(inputs, desiredOutputs);
for (int epoch = 0; epoch < nEpochs; epoch++)
    model.fit(ds);

但是,创建自定义迭代器是最安全的方法,特别是在较大的集合中,因为它可以让您更好地控制数据并使事情井井有条.

在您的DataSetIterator实现中,您必须传递数据,在next()方法的实现中,您应返回一个包含下一批训练数据的DataSet对象.它看起来像这样:

public class MyCustomIterator implements DataSetIterator {
    private INDArray inputs, desiredOutputs;
    private int itPosition = 0; // the iterator position in the set.

    public MyCustomIterator(float[] inputsArray,
                            float[] desiredOutputsArray,
                            int numSamples,
                            int inputDim,
                            int outputDim) {
        inputs = Nd4j.create(inputsArray, new int[]{numSamples, inputDim});
        desiredOutputs = Nd4j.create(desiredOutputsArray, new int[]{numSamples, outputDim});
    }

    public DataSet next(int num) {
        // get a view containing the next num samples and desired outs.
        INDArray dsInput = inputs.get(
            NDArrayIndex.interval(itPosition, itPosition + num),
            NDArrayIndex.all());
        INDArray dsDesired = desiredOutputs.get(
            NDArrayIndex.interval(itPosition, itPosition + num),
            NDArrayIndex.all());

        itPosition += num;

        return new DataSet(dsInput, dsDesired);
    }

    // implement the remaining virtual methods...

}

您在上面看到的NDArrayIndex方法用于访问INDArray的部分.然后,您可以将其用于培训:

MyCustomIterator it = new MyCustomIterator(
    inputs,
    desiredOutputs,
    numSamples,
    inputDim,
    outputDim);

for (int epoch = 0; epoch < nEpochs; epoch++)
    model.fit(it);

This example对您特别有用,因为它实现了LSTM网络,并且具有自定义的迭代器实现(可以作为实现其余方法的指南).另外,有关NDArray的更多信息,this是有用的.它提供了有关创建,修改和访问NDArray的各个部分的详细信息.

标签:deep-learning,lstm,deeplearning4j,java
来源: https://codeday.me/bug/20191025/1927988.html