/*
 * Decompiled with CFR 0.152.
 */
package org.genericsystem.cv.nn;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.PathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.WarpImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a small convolutional image classifier with DL4J early stopping, evaluates it on a held-out split
 * and saves it to disk.
 *
 * <p>Images are read from {@code <user.dir>/training}. Each image's class label is taken from its parent
 * directory via {@link ParentPathLabelGenerator}. The files are split 70/15/15 into training, validation
 * and test sets. The network is first trained on the unmodified training images. It is then trained again
 * on flipped and on warped variants, each phase starting from the best model of the previous phase.
 */
public class SimpleCNN {
    protected static final Logger log = LoggerFactory.getLogger(SimpleCNN.class);
    /** Seed for the network initialisation and for both random number generators below. */
    private static final int seed = 123;
    /** Random source for the data-augmentation transforms. */
    private static final Random rng = new Random(seed);
    private static final String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
    /** Random source used when sampling, balancing and splitting the input files. */
    public static final Random randNumGen = new Random(seed);
    /** Image dimensions handed to the record readers and to the network's input type. */
    private static final int height = 250;
    private static final int width = 200;
    private static final int channels = 3;
    protected static int iterations = 1;

    /**
     * Runs the full pipeline: split the data, train with early stopping (plain, then augmented),
     * evaluate the best model on the test split and save it as {@code TrainedModel-<millis>.zip}.
     *
     * @param args ignored
     * @throws Exception if reading the images, training, or writing the model fails
     */
    public static void main(String[] args) throws Exception {
        double learningRate = 0.005;
        int batchSize = 4;
        int nEpochs = 100;
        // Stop a phase once the validation loss has not improved for this many consecutive epochs.
        int maxEpochsWithoutImprovement = 20;

        File parentDir = new File(System.getProperty("user.dir"), "training");
        FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);
        PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // NOTE(review): confirm the meaning of the positional limits (0, 0, 100, 0) against the
        // BalancedPathFilter constructor of the DataVec version in use.
        PathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker, 0, 0, 100, 0, new String[0]);
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, new double[] {70.0, 15.0, 15.0});
        InputSplit trainData = filesInDirSplit[0];
        InputSplit validData = filesInDirSplit[1];
        InputSplit testData = filesInDirSplit[2];

        // Read the training split once to discover the class labels and their count.
        ImageRecordReader labelReader = new ImageRecordReader(height, width, channels, labelMaker);
        labelReader.initialize(trainData, null);
        List<String> labels = labelReader.getLabels();
        int outputNum = labelReader.numLabels();

        MultiLayerNetwork model = new MultiLayerNetwork(buildConfiguration(outputNum, learningRate));
        model.init();

        // The validation iterator owns its own record reader. With a shared reader, initializing
        // it for the training data would silently point validation at the training split.
        DataSetIterator validIter = getDataSetIterator(labelMaker, validData, null, batchSize, outputNum);
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                // Both conditions must go in ONE call: the setter replaces any previously set list.
                .epochTerminationConditions(new EpochTerminationCondition[] {
                        new MaxEpochsTerminationCondition(nEpochs),
                        new ScoreImprovementEpochTerminationCondition(maxEpochsWithoutImprovement)})
                .evaluateEveryNEpochs(1)
                .scoreCalculator((ScoreCalculator) new DataSetLossCalculator(validIter, true))
                .build();

        // Train the network instance created above (not a fresh one built from the configuration),
        // so the trained weights are what later gets evaluated and saved.
        EarlyStoppingResult<MultiLayerNetwork> result = fitWithEarlyStopping(esConf, model,
                getDataSetIterator(labelMaker, trainData, null, batchSize, outputNum));

        List<ImageTransform> transforms = Arrays.<ImageTransform>asList(
                new FlipImageTransform(rng),
                new WarpImageTransform(rng, 42.0f));
        for (ImageTransform transform : transforms) {
            log.info("Training on transformation: {}", transform.getClass());
            DataSetIterator augmentedIter = getDataSetIterator(labelMaker, trainData, transform, batchSize, outputNum);
            result = fitWithEarlyStopping(esConf, result.getBestModel(), augmentedIter);
        }

        MultiLayerNetwork bestModel = result.getBestModel();
        log.info("Evaluate model....");
        DataSetIterator testIter = getDataSetIterator(labelMaker, testData, null, batchSize, outputNum);
        Evaluation eval = bestModel.evaluate(testIter, labels);
        log.info(eval.stats(true));

        File modelFile = new File("TrainedModel-" + System.currentTimeMillis() + ".zip");
        ModelSerializer.writeModel(bestModel, modelFile, true);
        log.info("Model saved to {}.", modelFile);
    }

    /**
     * Builds the network: three convolution layers each followed by 2x2 pooling, two dense layers,
     * and a softmax output layer trained with multi-class cross-entropy.
     *
     * @param outputNum number of output classes
     * @param learningRate learning rate for the RMSProp updater
     * @return the multi-layer configuration
     */
    private static MultiLayerConfiguration buildConfiguration(int outputNum, double learningRate) {
        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerParamType)
                .activation(Activation.RELU)
                .l2(4.0E-4)
                .learningRate(learningRate)
                .weightInit(WeightInit.RELU)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.RMSPROP)
                .list()
                .layer(0, new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {2, 2}, new int[] {2, 2})
                        .name("inputLayer").nIn(channels).nOut(96).biasInit(0.0).build())
                .layer(1, maxPool("maxpool1"))
                .layer(2, new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {1, 1}, new int[] {1, 1})
                        .name("convLayer").nOut(256).biasInit(0.0).build())
                .layer(3, maxPool("maxpool2"))
                .layer(4, new ConvolutionLayer.Builder(new int[] {3, 3}, new int[] {1, 1}, new int[] {1, 1})
                        .name("convLayer2").nOut(256).biasInit(0.0).build())
                .layer(5, maxPool("maxpool3"))
                .layer(6, new DenseLayer.Builder().nOut(512).build())
                .layer(7, new DenseLayer.Builder().nOut(256).build())
                .layer(8, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nOut(outputNum).activation(Activation.SOFTMAX).build())
                .setInputType(InputType.convolutional(height, width, channels))
                .build();
    }

    /**
     * Runs one early-stopping training phase on {@code model}.
     *
     * @param esConf early-stopping configuration (termination conditions and validation scorer)
     * @param model network to train; a score listener is (re)attached before training
     * @param trainIter training data
     * @return the early-stopping result; its best model is guaranteed non-null
     * @throws IllegalStateException if the trainer produced no best model
     */
    private static EarlyStoppingResult<MultiLayerNetwork> fitWithEarlyStopping(
            EarlyStoppingConfiguration<MultiLayerNetwork> esConf, MultiLayerNetwork model, DataSetIterator trainIter) {
        // Re-attach the listener on every phase; the best model from a previous phase may be a copy
        // that no longer carries it. TODO(review): confirm whether clones keep listeners.
        model.setListeners(new IterationListener[] {new ScoreIterationListener(10)});
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, model, trainIter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
        if (result.getBestModel() == null) {
            throw new IllegalStateException("Early stopping produced no best model (termination: "
                    + result.getTerminationReason() + ", " + result.getTerminationDetails() + ")");
        }
        log.info("Early stopping finished ({}: {}); best epoch {} with score {}",
                result.getTerminationReason(), result.getTerminationDetails(),
                result.getBestModelEpoch(), result.getBestModelScore());
        return result;
    }

    /**
     * Creates a scaled data-set iterator over {@code data}, backed by its own new record reader.
     *
     * <p>Each iterator gets a dedicated reader on purpose. Re-initializing a shared reader would
     * change what every other iterator built on it reads.
     *
     * <p>NOTE(review): each reader presumably infers its label list from its own split. If a split
     * lacks some class, its label indices may not line up with the training reader's. Confirm that
     * every split contains every class.
     *
     * @param labelMaker derives each image's label from its path
     * @param data the split to read
     * @param transform optional augmentation to apply, or {@code null} for none
     * @param batchSize examples per mini-batch
     * @param outputNum number of label classes
     * @return an iterator whose features are scaled into [-1, 1]
     * @throws IOException if the record reader cannot be initialized
     */
    private static DataSetIterator getDataSetIterator(PathLabelGenerator labelMaker, InputSplit data,
            ImageTransform transform, int batchSize, int outputNum) throws IOException {
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
        recordReader.initialize(data, transform);
        // Label index 1: each record is (image, label).
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(-1.0, 1.0);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);
        return dataIter;
    }

    /**
     * Builds a 2x2, stride-2 subsampling layer with the given name. It uses the builder's default
     * pooling type, presumably MAX as the method name implies; TODO confirm for the DL4J version in use.
     */
    private static SubsamplingLayer maxPool(String name) {
        return new SubsamplingLayer.Builder(new int[] {2, 2}, new int[] {2, 2}).name(name).build();
    }
}

