java-如何在Deeplearning4j中使用自定义数据模型?
作者:互联网
基本问题是尝试使用自定义数据模型创建要在deeplearning4j网络中使用的DataSetIterator.
我要使用的数据模型是一个Java类,其中包含一堆双打,这些双打是根据特定股票的报价创建的,例如时间戳,开盘价,开盘价,收盘价,高价,低价,交易量,技术指标1,技术指标2,等等
我查询互联网资源example(也来自同一站点的其他几个指标),该资源提供了json字符串,我将其转换为数据模型以便于访问并存储在sqlite数据库中.
现在,我有了这些数据模型的清单,我想用它们来训练LSTM网络,每个模型都是一个功能.根据Deeplearning4j文档和几个示例,使用训练数据的方法是使用描述为here的ETL流程创建一个DataSetIterator,然后供网络使用.
在没有先将它们转换为其他格式(例如CSV或其他文件)的情况下,我看不到使用任何提供的RecordReader转换数据模型的干净方法.我想避免这种情况,因为它会占用大量资源.似乎会有更好的方法来处理此简单案例.有什么更好的方法我只是想念吗?
解决方法:
伊桑!
首先,Deeplearning4j使用ND4j作为后端,因此最终必须将您的数据转换为INDArray对象才能在模型中使用.如果trianing数据是两个双精度数组,即inputsArray和desiredOutputsArray,则可以执行以下操作:
INDArray inputs = Nd4j.create(inputsArray, new int[]{numSamples, inputDim});
INDArray desiredOutputs = Nd4j.create(desiredOutputsArray, new int[]{numSamples, outputDim});
然后,您可以直接使用这些向量训练模型:
for (int epoch = 0; epoch < nEpochs; epoch++)
model.fit(inputs, desiredOutputs);
另外,您可以创建一个DataSet对象并将其用于训练:
DataSet ds = new DataSet(inputs, desiredOutputs);
for (int epoch = 0; epoch < nEpochs; epoch++)
model.fit(ds);
但是,创建自定义迭代器是最安全的方法,特别是在较大的集合中,因为它可以让您更好地控制数据并使事情井井有条.
在您的DataSetIterator实现中,您必须传递数据,在next()方法的实现中,您应返回一个包含下一批训练数据的DataSet对象.它看起来像这样:
public class MyCustomIterator implements DataSetIterator {
private INDArray inputs, desiredOutputs;
private int itPosition = 0; // the iterator position in the set.
public MyCustomIterator(float[] inputsArray,
float[] desiredOutputsArray,
int numSamples,
int inputDim,
int outputDim) {
inputs = Nd4j.create(inputsArray, new int[]{numSamples, inputDim});
desiredOutputs = Nd4j.create(desiredOutputsArray, new int[]{numSamples, outputDim});
}
public DataSet next(int num) {
// get a view containing the next num samples and desired outs.
INDArray dsInput = inputs.get(
NDArrayIndex.interval(itPosition, itPosition + num),
NDArrayIndex.all());
INDArray dsDesired = desiredOutputs.get(
NDArrayIndex.interval(itPosition, itPosition + num),
NDArrayIndex.all());
itPosition += num;
return new DataSet(dsInput, dsDesired);
}
// implement the remaining virtual methods...
}
您在上面看到的NDArrayIndex方法用于访问INDArray的部分.然后,您可以将其用于培训:
MyCustomIterator it = new MyCustomIterator(
inputs,
desiredOutputs,
numSamples,
inputDim,
outputDim);
for (int epoch = 0; epoch < nEpochs; epoch++)
model.fit(it);
This example对您特别有用,因为它实现了LSTM网络,并且具有自定义的迭代器实现(可以作为实现其余方法的指南).另外,有关NDArray的更多信息,this是有用的.它提供了有关创建,修改和访问NDArray的各个部分的详细信息.
标签:deep-learning,lstm,deeplearning4j,java 来源: https://codeday.me/bug/20191025/1927988.html