/*
 * Decompiled with CFR 0.152.
 */
package org.genericsystem.cv.nn;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.PathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;
import org.genericsystem.cv.nn.EarlyStoppingGraphFeaturizedTrainer;
import org.genericsystem.cv.nn.ExistingMiniBatchMultiDataSetIterator;
import org.genericsystem.cv.nn.ImageClassRecordReader;
import org.genericsystem.cv.nn.ImageFeaturesRecordReader;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Fine-tunes an ImageNet-pretrained VGG16 on an image directory (labels taken from each image's parent folder via
 * {@link ParentPathLabelGenerator}). The classifier head is replaced by a dense {@code "fc3"} layer fed by both
 * VGG16's {@code "fc2"} activations and an extra {@code "features"} input, followed by a softmax output layer.
 * <p>
 * Everything up to {@link #FEATURIZED_LAYER} is frozen. The frozen part is run once over the train, validation and
 * test splits. The resulting featurized mini-batches are written to disk, and the unfrozen layers are then trained
 * from those files with early stopping.
 */
public class AdaptedVGG16MultiDataSet {
    private static final Logger log = LoggerFactory.getLogger(AdaptedVGG16MultiDataSet.class);

    /** Image height, in pixels, passed to the record readers. */
    private static final int HEIGHT = 224;
    /** Image width, in pixels, passed to the record readers. */
    private static final int WIDTH = 224;
    /** Number of image channels passed to the record readers. */
    private static final int CHANNELS = 3;
    /** Last frozen VGG16 vertex; the network is featurized up to (and saved at) this layer. */
    private static final String FEATURIZED_LAYER = "block5_pool";

    /**
     * Runs the whole pipeline:
     * <ol>
     * <li>configure the CUDA caches;</li>
     * <li>split the data set 70/15/15;</li>
     * <li>build the adapted network;</li>
     * <li>featurize and save every split;</li>
     * <li>train the unfrozen layers with early stopping;</li>
     * <li>evaluate on the test split;</li>
     * <li>write the model to {@code AdaptedVGG16-<timestamp>.zip} in the working directory.</li>
     * </ol>
     *
     * @param args ignored
     * @throws Exception if the pretrained model cannot be loaded or any training/serialization step fails
     */
    public static void main(String[] args) throws Exception {
        System.setProperty("org.bytedeco.javacpp.maxphysicalbytes", "8G");
        long gib = 1L << 30;
        // 1 GiB max cacheable allocation and 6 GiB total cache, on both device and host.
        CudaEnvironment.getInstance().getConfiguration()
                .setMaximumDeviceCacheableLength(gib)
                .setMaximumDeviceCache(6 * gib)
                .setMaximumHostCacheableLength(gib)
                .setMaximumHostCache(6 * gib);

        double learningRate = 0.005;
        int batchSize = 4;
        int nEpochs = 100;
        int seed = 123;

        String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
        Random randNumGen = new Random(seed);
        File parentDir = new File(System.getProperty("user.dir"), "training-grouped-augmented2");
        FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // NOTE(review): the integer bounds are positional; confirm against the DataVec BalancedPathFilter overload
        // in use which bound the value 1000 actually sets.
        BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, null, 0, 0, 1000, 0, new String[0]);
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 0.7, 0.15, 0.15);
        InputSplit trainData = filesInDirSplit[0];
        InputSplit validData = filesInDirSplit[1];
        InputSplit testData = filesInDirSplit[2];
        log.debug("trainData: {}, validData: {}, testData: {}", trainData.length(), validData.length(), testData.length());

        ImageRecordReader recordReader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS);
        ImageFeaturesRecordReader featuresReader = new ImageFeaturesRecordReader(HEIGHT, WIDTH, CHANNELS, null, null);
        ImageClassRecordReader outputReader = new ImageClassRecordReader(HEIGHT, WIDTH, CHANNELS, labelMaker);
        // Initialized on the training split only to discover the label set and its size.
        outputReader.initialize(trainData);
        List<String> labels = outputReader.getLabels();
        int outputNum = outputReader.numLabels();

        VGG16 zooModel = new VGG16();
        ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
        ComputationGraph net = buildNetwork(vgg16, learningRate, outputNum);
        TransferLearningHelper transferLearningHelper = new TransferLearningHelper(net, FEATURIZED_LAYER);
        ComputationGraph graph = transferLearningHelper.unfrozenGraph();

        // The same reader instances are re-initialized for every split, so each iterator is fully consumed
        // (by saveFeaturized) before the next one is built.
        List<RecordReader> readers = Arrays.<RecordReader>asList(recordReader, featuresReader, outputReader);
        List<String> names = Arrays.asList("image", "features", "output");
        saveFeaturized(getMultiDataSetIterator(readers, names, trainData, batchSize, outputNum), transferLearningHelper, "train");
        saveFeaturized(getMultiDataSetIterator(readers, names, validData, batchSize, outputNum), transferLearningHelper, "validation");
        saveFeaturized(getMultiDataSetIterator(readers, names, testData, batchSize, outputNum), transferLearningHelper, "test");

        EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(20),
                        new MaxEpochsTerminationCondition(nEpochs))
                .scoreCalculator(new DataSetLossCalculatorCG(getPresavedMultiIterator("validation"), true))
                .modelSaver(new LocalFileGraphSaver("/tmp"))
                .evaluateEveryNEpochs(1)
                .build();
        EarlyStoppingGraphFeaturizedTrainer trainer =
                new EarlyStoppingGraphFeaturizedTrainer(esConf, transferLearningHelper, getPresavedMultiIterator("train"));
        // NOTE(review): whatever trainer.fit() returns is ignored, so the evaluation and the saved model below use
        // the parameters left after the last epoch, not any best model kept by the /tmp model saver -- confirm.
        trainer.fit();

        Evaluation eval = graph.evaluate(getPresavedMultiIterator("test"), labels);
        log.info("Model evaluation:\n{}", eval.stats(true));

        File modelFile = new File("AdaptedVGG16-" + System.currentTimeMillis() + ".zip");
        ModelSerializer.writeModel(net, modelFile, true);
        log.info("Model saved to {}.", modelFile);
    }

    /**
     * Adapts the pretrained VGG16 as follows:
     * <ul>
     * <li>adds a {@code "features"} input;</li>
     * <li>freezes every vertex up to {@link #FEATURIZED_LAYER};</li>
     * <li>replaces the {@code "predictions"} head with a leaky-ReLU dense {@code "fc3"} layer (inputs {@code "fc2"}
     * and {@code "features"}), followed by a softmax/MCXENT output over {@code outputNum} classes.</li>
     * </ul>
     *
     * @param vgg16 pretrained VGG16 graph to adapt
     * @param learningRate learning rate for the fine-tuning configuration
     * @param outputNum number of output classes
     * @return the adapted graph
     */
    private static ComputationGraph buildNetwork(ComputationGraph vgg16, double learningRate, int outputNum) {
        FineTuneConfiguration fineTuneConfig = new FineTuneConfiguration.Builder()
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .learningRate(learningRate)
                .regularization(true)
                .build();
        // NOTE(review): nIn must equal the combined width of the "fc2" and "features" activations -- confirm 11840.
        DenseLayer fc3 = new DenseLayer.Builder()
                .activation(new ActivationLReLU(0.33))
                .weightInit(WeightInit.RELU)
                .dropOut(0.5)
                .nIn(11840)
                .nOut(2048)
                .build();
        OutputLayer predictions = new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(2048)
                .nOut(outputNum)
                .weightInit(WeightInit.RELU)
                .dropOut(0.5)
                .activation(new ActivationSoftmax())
                .build();
        return new TransferLearning.GraphBuilder(vgg16)
                .fineTuneConfiguration(fineTuneConfig)
                .addInputs("features")
                .setFeatureExtractor(FEATURIZED_LAYER)
                .removeVertexKeepConnections("predictions")
                .addLayer("fc3", fc3, "fc2", "features")
                .addLayer("predictions", predictions, "fc3")
                .setOutputs("predictions")
                .build();
    }

    /**
     * Returns an asynchronous iterator over the featurized mini-batches previously written by
     * {@link #saveFeaturized} for the split {@code name}.
     */
    private static MultiDataSetIterator getPresavedMultiIterator(String name) {
        ExistingMiniBatchMultiDataSetIterator existing =
                new ExistingMiniBatchMultiDataSetIterator(featurizedFolder(name), featurizedFilePrefix(name) + "%d.bin");
        return new AsyncMultiDataSetIterator(existing);
    }

    /**
     * Featurizes every mini-batch of {@code dataIter} through the frozen layers and writes each one to disk.
     * The iterator is reset afterwards.
     *
     * @param dataIter source mini-batches
     * @param transferLearningHelper helper holding the frozen part of the network
     * @param name split name, used for the output folder and file names
     */
    private static void saveFeaturized(MultiDataSetIterator dataIter, TransferLearningHelper transferLearningHelper, String name) {
        int batchIndex = 0;
        while (dataIter.hasNext()) {
            MultiDataSet featurized = transferLearningHelper.featurize((MultiDataSet) dataIter.next());
            saveToDisk(featurized, batchIndex++, name);
        }
        dataIter.reset();
    }

    /**
     * Writes one featurized mini-batch to {@code <name>Folder/images-<layer>-<name>-<iterNum>.bin}. Failures are
     * logged (with the stack trace), not propagated.
     */
    private static void saveToDisk(MultiDataSet currentFeaturized, int iterNum, String name) {
        File fileFolder = featurizedFolder(name);
        if (iterNum == 0 && !fileFolder.isDirectory() && !fileFolder.mkdirs()) {
            log.warn("Could not create folder {}.", fileFolder);
        }
        String fileName = featurizedFilePrefix(name) + iterNum + ".bin";
        try {
            currentFeaturized.save(new File(fileFolder, fileName));
            // Only report success when the save actually succeeded.
            log.info("Saved {} dataset #{}", name, iterNum);
        } catch (IOException e) {
            // The throwable goes last so SLF4J logs its stack trace; the file name fills the placeholder.
            log.error("Exception while saving file {}.", fileName, e);
        }
    }

    /** Folder holding the featurized mini-batches of split {@code name}. */
    private static File featurizedFolder(String name) {
        return new File(name + "Folder");
    }

    /** File-name prefix shared by the writer ({@link #saveToDisk}) and the reader pattern. */
    private static String featurizedFilePrefix(String name) {
        return "images-" + FEATURIZED_LAYER + "-" + name + "-";
    }

    /**
     * Builds a multi-input iterator over {@code data}.
     * <ul>
     * <li>Each reader is (re-)initialized on {@code data} and registered under the matching entry of {@code names}.</li>
     * <li>{@link ImageClassRecordReader}s become one-hot outputs of size {@code outputNum}; every other reader
     * becomes an input.</li>
     * <li>A [-1, 1] min-max scaler is fitted on the data and attached as pre-processor.</li>
     * </ul>
     * <p>
     * Because the reader instances are re-initialized, any iterator built earlier from the same readers should be
     * fully consumed before calling this again.
     * <p>
     * NOTE(review): a new scaler is fitted on each call's own data, so validation/test inputs are not scaled with
     * training statistics -- confirm this is intended.
     *
     * @param recordReaders readers to combine, in the same order as {@code names}
     * @param names one name per reader
     * @param data split to read
     * @param batchSize mini-batch size
     * @param outputNum number of classes for one-hot outputs
     * @return the reset, pre-processed iterator
     * @throws IllegalArgumentException if the two lists differ in size
     * @throws IllegalStateException if a reader cannot be initialized or initialization is interrupted
     */
    public static MultiDataSetIterator getMultiDataSetIterator(List<RecordReader> recordReaders, List<String> names, InputSplit data, int batchSize, int outputNum) {
        if (recordReaders.size() != names.size()) {
            throw new IllegalArgumentException("The lists of recordReaders and of names must have the same size. "
                    + recordReaders.size() + " recordReader(s), " + names.size() + " names given.");
        }
        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(batchSize);
        for (int i = 0; i < recordReaders.size(); i++) {
            RecordReader reader = recordReaders.get(i);
            String name = names.get(i);
            initializeReader(reader, data);
            builder.addReader(name, reader);
            if (reader instanceof ImageClassRecordReader) {
                builder.addOutputOneHot(name, 0, outputNum);
            } else {
                builder.addInput(name);
            }
        }
        RecordReaderMultiDataSetIterator iterator = builder.build();
        MultiNormalizerMinMaxScaler scaler = new MultiNormalizerMinMaxScaler(-1.0, 1.0);
        scaler.fit(iterator);
        log.debug("Scaler fit");
        iterator.setPreProcessor(scaler);
        iterator.reset();
        return iterator;
    }

    /**
     * Initializes {@code reader} on {@code data}. Failures are rethrown rather than logged, because continuing would
     * produce an iterator that is missing readers.
     */
    private static void initializeReader(RecordReader reader, InputSplit data) {
        try {
            reader.initialize(data);
        } catch (IOException e) {
            throw new IllegalStateException("Impossible to initialize recordReader.", e);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Initialization of recordReader interrupted.", e);
        }
    }
}

