deeplearning4j
作者:互联网
1.概述
Deeplearning4j(DL4J)是一个基于 Java 实现的深度学习框架,用于神经网络的搭建和模型训练。
2.demo
public class Demo { public static void main(String[] args) throws Exception { int height = 28; int width = 28; int channels = 1; // 这里有没有复杂的识别,没有分成红绿蓝三个通道 int outputNum = 10; // 有十个数字,所以输出为10 int batchSize = 54;//每次迭代取54张小批量来训练,可以查阅神经网络的mini batch相关优化,也就是小批量求平均梯度 int nEpochs = 1;//整个样本集只训练一次 int iterations = 1; int seed = 1234; Random randNumGen = new Random(seed); File trainData = new File(basePath + "/mnist_png/training"); FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); //以父级目录名作为分类的标签名 ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);//构造图片读取类 trainRR.initialize(trainSplit); DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum); // 把像素值区间 0-255 压缩到0-1 区间 DataNormalization scaler = new ImagePreProcessingScaler(0, 1); scaler.fit(trainIter); trainIter.setPreProcessor(scaler); // 向量化测试集 File testData = new File(basePath + "/mnist_png/testing"); FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker); testRR.initialize(testSplit); DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum); testIter.setPreProcessor(scaler); Map<Integer, Double> lrs = new HashMap<>(); lrs.put(0, 0.06); lrs.put(200, 0.05); lrs.put(600, 0.028); lrs.put(800, 0.0060); lrs.put(1000, 0.001); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .iterations(iterations) .regularization(true).l2(0.0005) .learningRate(.01) .learningRateDecayPolicy(LearningRatePolicy.Schedule) .learningRateSchedule(lrSchedule) .weightInit(WeightInit.XAVIER) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(Updater.NESTEROVS) .list() .layer(0, new ConvolutionLayer.Builder(5, 5) .nIn(channels) .stride(1, 1) 
.nOut(20) .activation(Activation.IDENTITY) .build()) .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build()) .layer(2, new ConvolutionLayer.Builder(5, 5) .stride(1, 1) .nOut(50) .activation(Activation.IDENTITY) .build()) .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build()) .layer(4, new DenseLayer.Builder().activation(Activation.RELU) .nOut(500).build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum) .activation(Activation.SOFTMAX) .build()) .setInputType(InputType.convolutionalFlat(28, 28, 1)) .backprop(true).pretrain(false).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.setListeners(new ScoreIterationListener(10)); log.debug("Total num of params: {}", net.numParams()); // 评估测试集 for (int i = 0; i < nEpochs; i++) { net.fit(trainIter); Evaluation eval = net.evaluate(testIter); log.info(eval.stats()); trainIter.reset(); testIter.reset(); } ModelSerializer.writeModel(net, new File(basePath + "/minist-model.zip"), true); } }
标签:int,Builder,deeplearning4j,lrs,build,new,net 来源: https://www.cnblogs.com/yangyang12138/p/13649367.html