package org.genericsystem.cv.nn;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.WarpImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/genericsystem/cv/nn/SimpleCNN.class */
/**
 * Trains a small convolutional network on the images found under
 * {@code <user.dir>/training}, where each image's parent directory name is
 * used as its label (via {@link ParentPathLabelGenerator}).
 *
 * <p>The files are sampled into train / validation / test splits with weights
 * 70 / 15 / 15. Training uses early stopping scored on the validation split,
 * first on the untransformed images and then once per image transformation.
 * The best model found is evaluated on the test split and written to a
 * timestamped zip file in the working directory.
 */
public class SimpleCNN {
    private static final int seed = 123;
    protected static final Logger log = LoggerFactory.getLogger(SimpleCNN.class);
    private static final Random rng = new Random(seed);
    private static final String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
    public static final Random randNumGen = new Random(seed);
    private static final int height = 250;
    private static final int width = 200;
    private static final int channels = 3;
    /** Mini-batch size used for every data set iterator. */
    private static final int batchSize = 4;
    protected static int iterations = 1;

    /**
     * Runs the full train / evaluate / save pipeline.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception if the image files cannot be read, training fails to
     *                   produce a model, or the model cannot be written
     */
    public static void main(String[] strArr) throws Exception {
        File trainingDir = new File(System.getProperty("user.dir"), "training");
        FileSplit allFiles = new FileSplit(trainingDir, allowedExtensions, randNumGen);
        ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
        BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelGenerator, 0, 0, 100, 0, new String[0]);
        InputSplit[] splits = allFiles.sample(pathFilter, new double[]{70.0d, 15.0d, 15.0d});
        InputSplit trainSplit = splits[0];
        InputSplit validationSplit = splits[1];
        InputSplit testSplit = splits[2];

        // A throw-away reader used only to discover the label set of the training split.
        ImageRecordReader labelReader = newRecordReader(labelGenerator);
        labelReader.initialize(trainSplit, (ImageTransform) null);
        List<String> labels = labelReader.getLabels();
        int numLabels = labelReader.numLabels();

        MultiLayerConfiguration configuration = buildConfiguration(numLabels);
        MultiLayerNetwork network = new MultiLayerNetwork(configuration);
        network.init();
        network.setListeners(new ScoreIterationListener(10));

        // The validation iterator gets its OWN reader: the score calculator keeps using it
        // for the whole training run, so it must not be re-initialised on another split.
        DataSetIterator validationIterator = getDataSetIterator(newRecordReader(labelGenerator), validationSplit, null, batchSize, numLabels);

        // Both termination conditions are registered in a single call so that neither
        // one is dropped by a later call to the same builder method.
        EarlyStoppingConfiguration<MultiLayerNetwork> earlyStopping = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(100), new ScoreImprovementEpochTerminationCondition(20))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationIterator, true))
                .build();

        // Train the network instance created above (not a fresh one built from the
        // configuration), so the listener and initialisation set up above are used.
        DataSetIterator trainIterator = getDataSetIterator(newRecordReader(labelGenerator), trainSplit, null, batchSize, numLabels);
        EarlyStoppingResult<MultiLayerNetwork> result = new EarlyStoppingTrainer(earlyStopping, network, trainIterator).fit();

        // Continue training the best model so far on each augmented view of the training split.
        List<ImageTransform> transforms = Arrays.<ImageTransform>asList(new FlipImageTransform(rng), new WarpImageTransform(rng, 42.0f));
        for (ImageTransform transform : transforms) {
            System.out.print("\nTraining on transformation: " + transform.getClass().toString() + "\n\n");
            DataSetIterator augmentedIterator = getDataSetIterator(newRecordReader(labelGenerator), trainSplit, transform, batchSize, numLabels);
            result = new EarlyStoppingTrainer(earlyStopping, requireBestModel(result), augmentedIterator).fit();
        }

        // Evaluate and persist the model that training actually produced.
        MultiLayerNetwork bestModel = requireBestModel(result);
        log.info("Evaluate model....");
        DataSetIterator testIterator = getDataSetIterator(newRecordReader(labelGenerator), testSplit, null, batchSize, numLabels);
        log.info(bestModel.evaluate(testIterator, labels).stats(true));
        File file = new File("TrainedModel-" + System.currentTimeMillis() + ".zip");
        ModelSerializer.writeModel(bestModel, file, true);
        log.info("Model saved to {}.", file);
    }

    /**
     * Builds the network configuration: three convolution + max-pooling stages,
     * two dense layers and a softmax output layer with one output per label.
     *
     * @param numLabels number of output classes
     * @return the multi-layer configuration
     */
    private static MultiLayerConfiguration buildConfiguration(int numLabels) {
        return new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerParamType)
                .activation(Activation.RELU)
                .l2(4.0E-4d)
                .learningRate(0.005d)
                .weightInit(WeightInit.RELU)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.RMSPROP)
                .list()
                .layer(0, new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{2, 2}, new int[]{2, 2}).name("inputLayer").nIn(channels).nOut(96).biasInit(0.0d).build())
                .layer(1, maxPool("maxpool1"))
                .layer(2, new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{1, 1}, new int[]{1, 1}).name("convLayer").nOut(256).biasInit(0.0d).build())
                .layer(3, maxPool("maxpool2"))
                .layer(4, new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}, new int[]{1, 1}).name("convLayer2").nOut(256).biasInit(0.0d).build())
                .layer(5, maxPool("maxpool3"))
                .layer(6, new DenseLayer.Builder().nOut(512).build())
                .layer(7, new DenseLayer.Builder().nOut(256).build())
                .layer(8, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nOut(numLabels).activation(Activation.SOFTMAX).build())
                .setInputType(InputType.convolutional(height, width, channels))
                .build();
    }

    /**
     * Creates a new, uninitialised image reader using the configured image dimensions.
     *
     * @param labelGenerator label generator handed to the reader
     * @return a fresh reader that is not shared with any other iterator
     */
    private static ImageRecordReader newRecordReader(ParentPathLabelGenerator labelGenerator) {
        return new ImageRecordReader(height, width, channels, labelGenerator);
    }

    /**
     * Returns the best model of an early-stopping run, failing loudly when there is none.
     *
     * @param result result of an early-stopping training run
     * @return the best model, never {@code null}
     * @throws IllegalStateException if the run produced no model
     */
    private static MultiLayerNetwork requireBestModel(EarlyStoppingResult<MultiLayerNetwork> result) {
        MultiLayerNetwork best = result.getBestModel();
        if (best == null) {
            throw new IllegalStateException("Early stopping finished without producing a model: " + result.getTerminationDetails());
        }
        return best;
    }

    /**
     * Initialises the given reader on a split and wraps it in a data set iterator
     * whose pre-processor is an {@link ImagePreProcessingScaler} built with the
     * range (-1, 1).
     *
     * <p>The reader is (re-)initialised by this call, so a reader that already backs
     * another live iterator must not be passed in.
     *
     * @param imageRecordReader reader to initialise and wrap
     * @param inputSplit        files to read
     * @param imageTransform    transformation applied to the images, or {@code null} for none
     * @param batchSize         mini-batch size
     * @param numLabels         number of possible labels
     * @return an iterator over the split
     * @throws IOException if the reader cannot be initialised on the split
     */
    private static DataSetIterator getDataSetIterator(ImageRecordReader imageRecordReader, InputSplit inputSplit, ImageTransform imageTransform, int batchSize, int numLabels) throws IOException {
        // Propagate initialisation failures instead of training on an unusable reader.
        imageRecordReader.initialize(inputSplit, imageTransform);
        // Label index 1: the label follows the image in each record.
        RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(imageRecordReader, batchSize, 1, numLabels);
        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(-1.0d, 1.0d);
        // NOTE(review): fit is kept from the original; for this fixed-range scaler it is
        // presumably a no-op — confirm against the library version in use.
        scaler.fit(iterator);
        iterator.setPreProcessor(scaler);
        return iterator;
    }

    /**
     * Creates a 2x2 subsampling layer with stride 2x2.
     *
     * @param str layer name
     * @return the subsampling layer
     */
    private static SubsamplingLayer maxPool(String str) {
        return new SubsamplingLayer.Builder(new int[]{2, 2}, new int[]{2, 2}).name(str).build();
    }
}
