package org.genericsystem.cv.nn;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Random;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.ExistingMiniBatchDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/genericsystem/cv/nn/AdaptedVGG16.class */
/**
 * Command-line entry point that adapts a pretrained VGG16 network to the image classes found
 * under the {@code training} directory of the current working directory.
 *
 * <p>The run proceeds in four steps:
 * <ol>
 *   <li>split the labelled images into train / validation / test sets (70% / 15% / 15%);</li>
 *   <li>push every batch through the frozen part of the network (up to
 *       {@link #FEATURIZED_LAYER}) and save the featurized batches to
 *       {@code trainFolder}, {@code validationFolder} and {@code testFolder};</li>
 *   <li>train the unfrozen part on the saved batches with early stopping;</li>
 *   <li>log an evaluation on the test batches and write the whole model to
 *       {@code AdaptedVGG16-<timestamp>.zip}.</li>
 * </ol>
 */
public class AdaptedVGG16 {
    private static final Logger log = LoggerFactory.getLogger(AdaptedVGG16.class);

    /** Layer at which the network is frozen; also embedded in the names of the saved batch files. */
    private static final String FEATURIZED_LAYER = "block5_pool";

    private static final long GIB = 1024L * 1024L * 1024L;
    private static final long RANDOM_SEED = 123L;
    private static final int IMAGE_HEIGHT = 224;
    private static final int IMAGE_WIDTH = 224;
    private static final int IMAGE_CHANNELS = 3;
    private static final int BATCH_SIZE = 4;

    /**
     * Runs the complete featurize / train / evaluate / save pipeline.
     *
     * @param args ignored
     * @throws Exception if the images cannot be read, the pretrained model cannot be loaded, or
     *         the trained model cannot be written
     */
    public static void main(String[] args) throws Exception {
        configureMemory();

        String[] allowedFormats = BaseImageLoader.ALLOWED_FORMATS;
        Random random = new Random(RANDOM_SEED);
        File trainingDir = new File(System.getProperty("user.dir"), "training");
        FileSplit allImages = new FileSplit(trainingDir, allowedFormats, random);
        ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
        // NOTE(review): the four numeric arguments are presumably maxPaths, maxLabels,
        // minPathsPerLabel and maxPathsPerLabel - confirm against the DataVec signature.
        BalancedPathFilter pathFilter =
                new BalancedPathFilter(random, allowedFormats, labelGenerator, 0, 0, 100, 0, new String[0]);
        InputSplit[] splits = allImages.sample(pathFilter, new double[] {0.7d, 0.15d, 0.15d});
        InputSplit trainSplit = splits[0];
        InputSplit validationSplit = splits[1];
        InputSplit testSplit = splits[2];

        ImageRecordReader recordReader =
                new ImageRecordReader(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, labelGenerator);
        // Initialise once up front so that the label set is known before the model is built.
        recordReader.initialize(trainSplit, (ImageTransform) null);
        List<String> labels = recordReader.getLabels();
        int numLabels = recordReader.numLabels();

        ComputationGraph model = buildModel(numLabels);
        TransferLearningHelper helper = new TransferLearningHelper(model, FEATURIZED_LAYER);
        ComputationGraph unfrozenGraph = helper.unfrozenGraph();
        unfrozenGraph.setListeners(new ScoreIterationListener(10));

        // The same record reader is re-initialised for each split, so every split has to be fully
        // consumed (featurized and saved) before the next iterator is requested.
        saveFeaturized(getDataSetIterator(recordReader, trainSplit, null, BATCH_SIZE, numLabels), helper, "train");
        saveFeaturized(getDataSetIterator(recordReader, validationSplit, null, BATCH_SIZE, numLabels), helper, "validation");
        saveFeaturized(getDataSetIterator(recordReader, testSplit, null, BATCH_SIZE, numLabels), helper, "test");

        EarlyStoppingConfiguration<ComputationGraph> earlyStopping =
                new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                        .evaluateEveryNEpochs(1)
                        .epochTerminationConditions(
                                new ScoreImprovementEpochTerminationCondition(20),
                                new MaxEpochsTerminationCondition(100))
                        .scoreCalculator(new DataSetLossCalculatorCG(getPresavedIterator("validation"), true))
                        .modelSaver(new LocalFileGraphSaver("/tmp"))
                        .build();
        new EarlyStoppingGraphFeaturizedTrainer(earlyStopping, helper, getPresavedIterator("train")).fit();

        log.info("Model evaluation:\n{}", unfrozenGraph.evaluate(getPresavedIterator("test"), labels).stats(true));

        File modelFile = new File("AdaptedVGG16-" + System.currentTimeMillis() + ".zip");
        ModelSerializer.writeModel(model, modelFile, true);
        log.info("Model saved to {}.", modelFile);
    }

    /** Sets the JavaCPP physical-memory limit and the CUDA device/host cache sizes used by this run. */
    private static void configureMemory() {
        System.setProperty("org.bytedeco.javacpp.maxphysicalbytes", "8G");
        CudaEnvironment.getInstance().getConfiguration()
                .setMaximumDeviceCacheableLength(GIB)
                .setMaximumDeviceCache(6L * GIB)
                .setMaximumHostCacheableLength(GIB)
                .setMaximumHostCache(6L * GIB);
    }

    /**
     * Loads the ImageNet-pretrained VGG16, freezes it up to {@link #FEATURIZED_LAYER} and replaces
     * its {@code predictions} vertex with a new softmax output layer of {@code numLabels} outputs.
     */
    private static ComputationGraph buildModel(int numLabels) throws IOException {
        // The explicit cast is harmless if initPretrained already returns a ComputationGraph.
        ComputationGraph pretrained = (ComputationGraph) new VGG16().initPretrained(PretrainedType.IMAGENET);
        FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder()
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .learningRate(0.005d)
                .regularization(true)
                .build();
        // NOTE(review): the new output layer is wired to a vertex named "fc3" with nIn = 2048;
        // confirm that both match the graph that initPretrained actually returns.
        return new TransferLearning.GraphBuilder(pretrained)
                .fineTuneConfiguration(fineTuneConfiguration)
                .setFeatureExtractor(FEATURIZED_LAYER)
                .removeVertexKeepConnections("predictions")
                .addLayer("predictions",
                        new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                .nIn(2048)
                                .nOut(numLabels)
                                .weightInit(WeightInit.RELU)
                                .dropOut(0.5d)
                                .activation(new ActivationSoftmax())
                                .build(),
                        "fc3")
                .build();
    }

    /** Folder that holds the featurized batches of the given data set ({@code <name>Folder}). */
    private static File datasetFolder(String name) {
        return new File(name + "Folder");
    }

    /** File-name prefix shared by the writer and the reader of featurized batches. */
    private static String batchFilePrefix(String name) {
        return "images-" + FEATURIZED_LAYER + "-" + name + "-";
    }

    /**
     * Returns an asynchronous iterator over the featurized batches previously written by
     * {@link #saveFeaturized} for the given data set name.
     */
    private static DataSetIterator getPresavedIterator(String name) {
        DataSetIterator savedBatches =
                new ExistingMiniBatchDataSetIterator(datasetFolder(name), batchFilePrefix(name) + "%d.bin");
        return new AsyncDataSetIterator(savedBatches);
    }

    /**
     * Featurizes every batch of {@code dataSetIterator} with {@code transferLearningHelper} and
     * saves the results, numbered from 0, under the folder of the given data set name.
     */
    private static void saveFeaturized(DataSetIterator dataSetIterator, TransferLearningHelper transferLearningHelper, String name) {
        int batchIndex = 0;
        while (dataSetIterator.hasNext()) {
            DataSet featurized = transferLearningHelper.featurize(dataSetIterator.next());
            saveToDisk(featurized, batchIndex, name);
            batchIndex++;
        }
        if (batchIndex == 0) {
            log.warn("No batches were available for the {} dataset; nothing was saved.", name);
        }
    }

    /**
     * Writes one featurized batch to {@code <name>Folder/images-<layer>-<name>-<index>.bin},
     * creating the folder when it does not exist yet.
     *
     * @throws IllegalStateException if the output folder cannot be created
     */
    private static void saveToDisk(org.nd4j.linalg.dataset.api.DataSet dataSet, int index, String name) {
        File folder = datasetFolder(name);
        if (!folder.isDirectory() && !folder.mkdirs()) {
            throw new IllegalStateException("Unable to create output folder " + folder.getAbsolutePath());
        }
        dataSet.save(new File(folder, batchFilePrefix(name) + index + ".bin"));
        log.info("Saved {} dataset #{}", name, index);
    }

    /**
     * Re-initialises {@code imageRecordReader} on {@code inputSplit} and wraps it in a batch
     * iterator whose pre-processor is an {@link ImagePreProcessingScaler} configured with
     * {@code (-1, 1)}.
     *
     * @param imageRecordReader reader to (re)initialise; it is mutated by this call
     * @param inputSplit images to iterate over
     * @param imageTransform transform handed to the reader, may be {@code null}
     * @param batchSize number of examples per batch
     * @param numLabels number of possible labels
     * @return the iterator, reset and ready to be consumed
     * @throws UncheckedIOException if the reader cannot be initialised on {@code inputSplit};
     *         continuing would iterate over whatever split the reader was holding before
     */
    protected static DataSetIterator getDataSetIterator(ImageRecordReader imageRecordReader, InputSplit inputSplit, ImageTransform imageTransform, int batchSize, int numLabels) {
        try {
            imageRecordReader.initialize(inputSplit, imageTransform);
        } catch (IOException e) {
            throw new UncheckedIOException("Impossible to initialize recordReader.", e);
        }
        RecordReaderDataSetIterator iterator = new RecordReaderDataSetIterator(imageRecordReader, batchSize, 1, numLabels);
        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(-1.0d, 1.0d);
        scaler.fit(iterator);
        iterator.setPreProcessor(scaler);
        iterator.reset();
        return iterator;
    }
}
