/*
 * Decompiled with CFR 0.152.
 */
package org.genericsystem.cv.nn;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Random;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.PathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;
import org.genericsystem.cv.nn.EarlyStoppingGraphFeaturizedTrainer;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.dataset.ExistingMiniBatchDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AdaptedVGG16 {
    private static final Logger log = LoggerFactory.getLogger(AdaptedVGG16.class);

    /**
     * Name of the VGG16 vertex up to which the network is frozen. Its activations are
     * computed once per split and cached on disk.
     */
    private static final String FEATURIZED_LAYER = "block5_pool";

    /** One gibibyte, used to express the CUDA cache limits in {@link #main}. */
    private static final long GIB = 1L << 30;

    /** Number of units of VGG16's {@code fc2} dense layer, which feeds the output layer. */
    private static final int FC2_SIZE = 4096;

    /**
     * Fine-tunes an ImageNet-pretrained VGG16 on the images under {@code ./training}.
     *
     * <p>Pipeline:
     * <ol>
     *   <li>Split the images 70/15/15 into train/validation/test.</li>
     *   <li>Replace the {@code predictions} layer with one sized for the local labels.</li>
     *   <li>Freeze everything up to {@value #FEATURIZED_LAYER}.</li>
     *   <li>Featurize every split once into {@code <split>Folder/}.</li>
     *   <li>Train the unfrozen head with early stopping.</li>
     *   <li>Evaluate on the test split and write the model to a timestamped zip.</li>
     * </ol>
     *
     * @param args ignored
     * @throws Exception if loading the pretrained model, reading images, or saving fails
     */
    public static void main(String[] args) throws Exception {
        // Off-heap memory cap for JavaCPP.
        // NOTE(review): this presumably must be set before any native library is loaded.
        System.setProperty("org.bytedeco.javacpp.maxphysicalbytes", "8G");
        // CUDA caches: 1 GiB max per cacheable allocation, 6 GiB total, on device and on host.
        CudaEnvironment.getInstance().getConfiguration()
                .setMaximumDeviceCacheableLength(GIB)
                .setMaximumDeviceCache(6 * GIB)
                .setMaximumHostCacheableLength(GIB)
                .setMaximumHostCache(6 * GIB);

        double learningRate = 0.005;
        int batchSize = 4;
        int nEpochs = 100;
        int height = 224;
        int width = 224;
        int channels = 3;
        int seed = 123;

        // Images come from ./training and are labelled by ParentPathLabelGenerator.
        // NOTE(review): the 0, 0, 100, 0 BalancedPathFilter arguments are path/label count
        // limits; confirm which parameter the 100 binds to for the DataVec version in use.
        String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
        Random randNumGen = new Random(seed);
        File parentDir = new File(System.getProperty("user.dir"), "training");
        FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        BalancedPathFilter pathFilter =
                new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker, 0, 0, 100, 0, new String[0]);
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, new double[] {0.7, 0.15, 0.15});
        InputSplit trainData = filesInDirSplit[0];
        InputSplit validData = filesInDirSplit[1];
        InputSplit testData = filesInDirSplit[2];

        // Initialize on the training split first so the label set and count are known.
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
        recordReader.initialize(trainData, null);
        List<String> labels = recordReader.getLabels();
        int outputNum = recordReader.numLabels();

        ComputationGraph net = buildTransferNetwork(outputNum, learningRate);
        TransferLearningHelper transferLearningHelper = new TransferLearningHelper(net, FEATURIZED_LAYER);
        // The unfrozen sub-graph consumes featurized data, not raw images.
        ComputationGraph graph = transferLearningHelper.unfrozenGraph();
        graph.setListeners(new ScoreIterationListener(10));

        // Run the frozen part once per split and cache its output.
        // Each iterator is drained before the shared recordReader is re-initialized
        // for the next split.
        saveFeaturized(getDataSetIterator(recordReader, trainData, null, batchSize, outputNum),
                transferLearningHelper, "train");
        saveFeaturized(getDataSetIterator(recordReader, validData, null, batchSize, outputNum),
                transferLearningHelper, "validation");
        saveFeaturized(getDataSetIterator(recordReader, testData, null, batchSize, outputNum),
                transferLearningHelper, "test");

        EarlyStoppingConfiguration<ComputationGraph> esConf = buildEarlyStoppingConfig(nEpochs);
        EarlyStoppingGraphFeaturizedTrainer trainer =
                new EarlyStoppingGraphFeaturizedTrainer(esConf, transferLearningHelper, getPresavedIterator("train"));
        trainer.fit();

        Evaluation eval = graph.evaluate(getPresavedIterator("test"), labels);
        log.info("Model evaluation:\n{}", eval.stats(true));

        File modelFile = new File("AdaptedVGG16-" + System.currentTimeMillis() + ".zip");
        ModelSerializer.writeModel(net, modelFile, true);
        log.info("Model saved to {}.", modelFile);
    }

    /**
     * Loads ImageNet-pretrained VGG16 and replaces its {@code predictions} layer.
     *
     * <p>The new softmax output layer has {@code outputNum} classes and is fed by {@code fc2}.
     * All layers up to {@value #FEATURIZED_LAYER} are frozen.
     *
     * @param outputNum number of target classes
     * @param learningRate learning rate applied to the fine-tuned (unfrozen) layers
     * @return the edited graph
     * @throws IOException if the pretrained weights cannot be loaded
     */
    private static ComputationGraph buildTransferNetwork(int outputNum, double learningRate) throws IOException {
        VGG16 zooModel = new VGG16();
        ComputationGraph vgg16 = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
        FineTuneConfiguration fineTuneConfig = new FineTuneConfiguration.Builder()
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .learningRate(Double.valueOf(learningRate))
                .regularization(true)
                .build();
        // The original "predictions" layer was fed by "fc2" (4096 units). The replacement
        // must be wired to the same vertex with a matching nIn. The previous code referenced
        // a non-existent "fc3" vertex with nIn 2048.
        OutputLayer predictions = new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(FC2_SIZE)
                .nOut(outputNum)
                .weightInit(WeightInit.RELU)
                .dropOut(0.5)
                .activation(new ActivationSoftmax())
                .build();
        return new TransferLearning.GraphBuilder(vgg16)
                .fineTuneConfiguration(fineTuneConfig)
                .setFeatureExtractor(FEATURIZED_LAYER)
                .removeVertexKeepConnections("predictions")
                .addLayer("predictions", predictions, "fc2")
                .build();
    }

    /**
     * Builds the early-stopping configuration for training on featurized data.
     *
     * <p>Validation loss is evaluated every epoch on the cached "validation" split.
     * Training stops after 20 epochs without improvement or after {@code nEpochs} epochs.
     * Models are saved under {@code /tmp}.
     *
     * @param nEpochs maximum number of epochs
     * @return the typed configuration
     */
    private static EarlyStoppingConfiguration<ComputationGraph> buildEarlyStoppingConfig(int nEpochs) {
        return new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                .evaluateEveryNEpochs(1)
                .epochTerminationConditions(
                        new ScoreImprovementEpochTerminationCondition(20),
                        new MaxEpochsTerminationCondition(nEpochs))
                .scoreCalculator(new DataSetLossCalculatorCG(getPresavedIterator("validation"), true))
                .modelSaver(new LocalFileGraphSaver("/tmp"))
                .build();
    }

    /** Returns the directory holding the featurized batches of split {@code name}. */
    private static File featurizedFolder(String name) {
        return new File(name + "Folder");
    }

    /** Returns the {@code %d} file-name pattern of featurized batches of split {@code name}. */
    private static String featurizedFilePattern(String name) {
        return "images-" + FEATURIZED_LAYER + "-" + name + "-%d.bin";
    }

    /**
     * Returns an asynchronous iterator over the featurized batches previously written by
     * {@link #saveFeaturized} for split {@code name}.
     *
     * @param name split name ("train", "validation" or "test")
     * @return iterator over the cached batches
     */
    private static DataSetIterator getPresavedIterator(String name) {
        DataSetIterator presaved =
                new ExistingMiniBatchDataSetIterator(featurizedFolder(name), featurizedFilePattern(name));
        return new AsyncDataSetIterator(presaved);
    }

    /**
     * Featurizes every batch of {@code dataIter} through the frozen layers and saves the
     * batches to disk as {@code 0, 1, 2, ...}.
     *
     * <p>NOTE(review): files from an earlier run are overwritten but not deleted. If that run
     * produced more batches, the extra files may still be picked up by the reader; confirm.
     *
     * @param dataIter raw-image iterator; fully consumed
     * @param transferLearningHelper helper that runs the frozen part of the network
     * @param name split name used for the folder and file names
     */
    private static void saveFeaturized(DataSetIterator dataIter, TransferLearningHelper transferLearningHelper, String name) {
        int batchIndex = 0;
        while (dataIter.hasNext()) {
            org.nd4j.linalg.dataset.DataSet featurized = transferLearningHelper.featurize(dataIter.next());
            saveToDisk(featurized, batchIndex, name);
            batchIndex++;
        }
    }

    /**
     * Writes one featurized batch to {@code <name>Folder/}.
     *
     * <p>The folder is created on the first batch. The file name is produced from the same
     * pattern the reader in {@link #getPresavedIterator} is given, so the two always agree.
     *
     * @param currentFeaturized the batch to save
     * @param iterNum zero-based batch index
     * @param name split name
     * @throws UncheckedIOException if the output folder cannot be created
     */
    private static void saveToDisk(DataSet currentFeaturized, int iterNum, String name) {
        File fileFolder = featurizedFolder(name);
        if (iterNum == 0 && !fileFolder.isDirectory() && !fileFolder.mkdirs()) {
            throw new UncheckedIOException(new IOException("Could not create directory " + fileFolder));
        }
        String fileName = String.format(featurizedFilePattern(name), iterNum);
        currentFeaturized.save(new File(fileFolder, fileName));
        log.info("Saved {} dataset #{}", name, iterNum);
    }

    /**
     * Re-initializes {@code recordReader} on {@code data} and wraps it in a batched iterator.
     *
     * <p>Pixels are scaled into [-1, 1].
     *
     * <p>Because the reader is shared, iterators previously returned for it should be fully
     * consumed before calling this again.
     *
     * @param recordReader reader to (re)initialize; mutated
     * @param data split to read
     * @param transform optional image transform, may be {@code null}
     * @param batchSize mini-batch size
     * @param outputNum number of label classes
     * @return iterator positioned at the first batch
     * @throws UncheckedIOException if the reader cannot be initialized on {@code data}
     */
    protected static DataSetIterator getDataSetIterator(ImageRecordReader recordReader, InputSplit data, ImageTransform transform, int batchSize, int outputNum) {
        try {
            recordReader.initialize(data, transform);
        } catch (IOException e) {
            // Continuing with an uninitialized reader only defers the failure; surface it here.
            throw new UncheckedIOException("Impossible to initialize recordReader.", e);
        }
        // Label expected at writable index 1 (after the image). NOTE(review): confirm for ImageRecordReader.
        RecordReaderDataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, outputNum);
        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(-1.0, 1.0);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);
        dataIter.reset();
        return dataIter;
    }
}

