package org.genericsystem.cv.nn;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/genericsystem/cv/nn/AdaptedVGG16MultiDataSet.class */
public class AdaptedVGG16MultiDataSet {
    private static final Logger log = LoggerFactory.getLogger(AdaptedVGG16MultiDataSet.class);

    /** Height in pixels of the images produced by the record readers. */
    private static final int HEIGHT = 224;
    /** Width in pixels of the images produced by the record readers. */
    private static final int WIDTH = 224;
    /** Number of colour channels of the images produced by the record readers. */
    private static final int CHANNELS = 3;
    /** Name of the VGG16 vertex used as the frozen feature-extractor boundary. */
    private static final String FEATURIZED_LAYER = "block5_pool";
    /** Mini-batch size used when featurizing each split. */
    private static final int BATCH_SIZE = 4;

    /**
     * Featurizes the image dataset through a frozen, ImageNet-pretrained VGG16, trains the new head
     * with early stopping on the pre-saved features, evaluates it on the test split and writes the
     * resulting model to {@code AdaptedVGG16-<timestamp>.zip} in the working directory.
     *
     * @param strArr command-line arguments (unused)
     * @throws Exception if loading the pretrained model, reading the data or saving the model fails
     */
    public static void main(String[] strArr) throws Exception {
        System.setProperty("org.bytedeco.javacpp.maxphysicalbytes", "8G");
        // CUDA caches: 1 GiB (1073741824 bytes) per cacheable allocation, 6 GiB in total,
        // on both the device and the host side.
        CudaEnvironment.getInstance().getConfiguration()
                .setMaximumDeviceCacheableLength(1073741824L)
                .setMaximumDeviceCache(6442450944L)
                .setMaximumHostCacheableLength(1073741824L)
                .setMaximumHostCache(6442450944L);

        String[] allowedFormats = BaseImageLoader.ALLOWED_FORMATS;
        Random random = new Random(123);
        FileSplit fileSplit = new FileSplit(
                new File(System.getProperty("user.dir"), "training-grouped-augmented2"), allowedFormats, random);
        ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
        // Balanced sampling of the files, then a 70% / 15% / 15% train / validation / test split.
        InputSplit[] splits = fileSplit.sample(
                new BalancedPathFilter(random, allowedFormats, (PathLabelGenerator) null, 0, 0, 1000, 0, new String[0]),
                new double[] {0.7, 0.15, 0.15});
        InputSplit trainSplit = splits[0];
        InputSplit validationSplit = splits[1];
        InputSplit testSplit = splits[2];
        log.debug("trainData: {}, validData: {}, testData: {}",
                trainSplit.length(), validationSplit.length(), testSplit.length());

        RecordReader imageRecordReader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS);
        RecordReader imageFeaturesRecordReader = new ImageFeaturesRecordReader(HEIGHT, WIDTH, CHANNELS, null, null);
        RecordReader imageClassRecordReader = new ImageClassRecordReader(HEIGHT, WIDTH, CHANNELS, labelGenerator);
        // Initialised up front so the label names and count are known before the graph is built.
        imageClassRecordReader.initialize(trainSplit);
        List<String> labels = imageClassRecordReader.getLabels();
        int numLabels = imageClassRecordReader.numLabels();

        ComputationGraph build = buildAdaptedGraph(numLabels);
        TransferLearningHelper transferLearningHelper = new TransferLearningHelper(build, FEATURIZED_LAYER);
        ComputationGraph unfrozenGraph = transferLearningHelper.unfrozenGraph();

        // The same reader instances are re-initialised for each split; this is safe only because
        // each split is fully consumed by saveFeaturized before the next one is created.
        List<RecordReader> readers = Arrays.asList(imageRecordReader, imageFeaturesRecordReader, imageClassRecordReader);
        List<String> readerNames = Arrays.asList("image", "features", "output");
        saveFeaturized(getMultiDataSetIterator(readers, readerNames, trainSplit, BATCH_SIZE, numLabels),
                transferLearningHelper, "train");
        saveFeaturized(getMultiDataSetIterator(readers, readerNames, validationSplit, BATCH_SIZE, numLabels),
                transferLearningHelper, "validation");
        saveFeaturized(getMultiDataSetIterator(readers, readerNames, testSplit, BATCH_SIZE, numLabels),
                transferLearningHelper, "test");

        // Stop after 20 epochs without score improvement or after 100 epochs, whichever comes first;
        // the score is computed on the pre-saved validation features after every epoch.
        EarlyStoppingConfiguration<ComputationGraph> earlyStopping = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                .epochTerminationConditions(
                        new ScoreImprovementEpochTerminationCondition(20), new MaxEpochsTerminationCondition(100))
                .scoreCalculator(new DataSetLossCalculatorCG(getPresavedMultiIterator("validation"), true))
                .modelSaver(new LocalFileGraphSaver("/tmp"))
                .evaluateEveryNEpochs(1)
                .build();
        new EarlyStoppingGraphFeaturizedTrainer(earlyStopping, transferLearningHelper, getPresavedMultiIterator("train")).fit();

        log.info("Model evaluation:\n{}", unfrozenGraph.evaluate(getPresavedMultiIterator("test"), labels).stats(true));
        File file = new File("AdaptedVGG16-" + System.currentTimeMillis() + ".zip");
        ModelSerializer.writeModel(build, file, true);
        log.info("Model saved to {}.", file);
    }

    /**
     * Builds the transfer-learning graph.
     *
     * <p>The graph is an ImageNet-pretrained VGG16 with {@link #FEATURIZED_LAYER} as its feature
     * extractor. An extra {@code "features"} input is added. The original {@code "predictions"}
     * vertex is replaced by a dense {@code "fc3"} layer, fed by {@code "fc2"} and {@code "features"},
     * followed by a new softmax {@code "predictions"} output layer.
     *
     * @param numLabels number of output classes
     * @return the adapted graph
     * @throws IOException if the pretrained weights cannot be loaded
     */
    private static ComputationGraph buildAdaptedGraph(int numLabels) throws IOException {
        FineTuneConfiguration fineTune = new FineTuneConfiguration.Builder()
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .learningRate(0.005)
                .regularization(true)
                .build();
        // NOTE(review): nIn must equal fc2's output size plus the width of the "features" input;
        // 11840 is hard-coded here — confirm it matches ImageFeaturesRecordReader's output.
        DenseLayer fc3 = new DenseLayer.Builder()
                .activation(new ActivationLReLU(0.33))
                .weightInit(WeightInit.RELU)
                .dropOut(0.5)
                .nIn(11840)
                .nOut(2048)
                .build();
        OutputLayer predictions = new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(2048)
                .nOut(numLabels)
                .weightInit(WeightInit.RELU)
                .dropOut(0.5)
                .activation(new ActivationSoftmax())
                .build();
        return new TransferLearning.GraphBuilder(new VGG16().initPretrained(PretrainedType.IMAGENET))
                .fineTuneConfiguration(fineTune)
                .addInputs("features")
                .setFeatureExtractor(FEATURIZED_LAYER)
                .removeVertexKeepConnections("predictions")
                .addLayer("fc3", fc3, "fc2", "features")
                .addLayer("predictions", predictions, "fc3")
                .setOutputs("predictions")
                .build();
    }

    /**
     * Returns an asynchronous iterator over the featurized mini-batches that
     * {@link #saveFeaturized} wrote for the given split.
     *
     * @param str split name ({@code "train"}, {@code "validation"} or {@code "test"})
     * @return iterator over the pre-saved batches of that split
     */
    private static MultiDataSetIterator getPresavedMultiIterator(String str) {
        return new AsyncMultiDataSetIterator(
                new ExistingMiniBatchMultiDataSetIterator(featurizedFolder(str), featurizedFileName(str, "%d")));
    }

    /**
     * Featurizes every mini-batch of {@code multiDataSetIterator} through the frozen part of the
     * graph and saves each result to disk with a 0-based batch index. The iterator is reset afterwards.
     *
     * @param multiDataSetIterator source of raw mini-batches (consumed, then reset)
     * @param transferLearningHelper helper whose frozen layers produce the features
     * @param str split name used for the output folder and file names
     */
    private static void saveFeaturized(MultiDataSetIterator multiDataSetIterator,
            TransferLearningHelper transferLearningHelper, String str) {
        int batchIndex = 0;
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet batch = (MultiDataSet) multiDataSetIterator.next();
            saveToDisk(transferLearningHelper.featurize(batch), batchIndex++, str);
        }
        multiDataSetIterator.reset();
    }

    /**
     * Writes one featurized mini-batch into the split's folder. The folder is created on the first
     * batch. I/O failures are logged rather than propagated.
     *
     * @param multiDataSet featurized batch to save
     * @param i 0-based batch index
     * @param str split name
     */
    private static void saveToDisk(MultiDataSet multiDataSet, int i, String str) {
        File folder = featurizedFolder(str);
        if (i == 0 && !folder.mkdirs() && !folder.isDirectory()) {
            log.warn("Could not create directory {}.", folder);
        }
        String fileName = featurizedFileName(str, Integer.toString(i));
        try {
            multiDataSet.save(new File(folder, fileName));
            log.info("Saved {} dataset #{}", str, i);
        } catch (IOException e) {
            // The Throwable must be the last argument so SLF4J logs its stack trace.
            log.error("Exception while saving file {}.", fileName, e);
        }
    }

    /** Directory holding the featurized mini-batches of a split. */
    private static File featurizedFolder(String split) {
        return new File(split + "Folder");
    }

    /**
     * File name of a featurized mini-batch. This is shared by the writer ({@link #saveToDisk}) and
     * the reader ({@link #getPresavedMultiIterator}, which passes {@code "%d"} as the index).
     */
    private static String featurizedFileName(String split, String index) {
        return "images-" + FEATURIZED_LAYER + "-" + split + "-" + index + ".bin";
    }

    /**
     * Builds a multi-input iterator over {@code inputSplit}.
     *
     * <p>Each reader is initialised on the split and registered under its matching name. Readers
     * that are {@code ImageClassRecordReader}s become one-hot outputs; all other readers become
     * inputs. A min/max scaler mapping values to [-1, 1] is fitted on this iterator and attached
     * as its pre-processor.
     *
     * <p>NOTE(review): every call fits its own scaler, so validation/test data are scaled with their
     * own statistics rather than the training split's — confirm this is intended.
     *
     * @param list record readers to combine
     * @param list2 names under which the readers are registered (same order and size as {@code list})
     * @param inputSplit split to initialise every reader on
     * @param i mini-batch size
     * @param i2 number of classes for the one-hot output
     * @return the configured, normalised and reset iterator
     * @throws IllegalArgumentException if {@code list} and {@code list2} differ in size
     */
    public static MultiDataSetIterator getMultiDataSetIterator(List<RecordReader> list, List<String> list2,
            InputSplit inputSplit, int i, int i2) {
        if (list.size() != list2.size()) {
            throw new IllegalArgumentException("The lists of recordReaders and of names must have the same size. "
                    + list.size() + " recordReader(s), " + list2.size() + " names given.");
        }
        RecordReaderMultiDataSetIterator.Builder builder = new RecordReaderMultiDataSetIterator.Builder(i);
        for (int index = 0; index < list.size(); index++) {
            RecordReader recordReader = list.get(index);
            String name = list2.get(index);
            try {
                recordReader.initialize(inputSplit);
            } catch (IOException e) {
                // NOTE(review): the reader is skipped, so the iterator is built without it.
                log.error("Impossible to initialize recordReader.", e);
                continue;
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers further up can still observe it.
                Thread.currentThread().interrupt();
                log.error("Initialization of recordReader interrupted.", e);
                continue;
            }
            builder.addReader(name, recordReader);
            if (recordReader instanceof ImageClassRecordReader) {
                builder.addOutputOneHot(name, 0, i2);
            } else {
                builder.addInput(name);
            }
        }
        RecordReaderMultiDataSetIterator build = builder.build();
        MultiNormalizerMinMaxScaler scaler = new MultiNormalizerMinMaxScaler(-1.0, 1.0);
        scaler.fit(build);
        log.debug("Scaler fit");
        build.setPreProcessor(scaler);
        // Fitting consumed the iterator; rewind it for the caller.
        build.reset();
        return build;
    }
}
