package org.genericsystem.cv.nn;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.ImageTransform;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.genericsystem.cv.utils.NativeLibraryLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/genericsystem/cv/nn/TestNet.class */
/**
 * Ad-hoc driver that restores a trained {@link ComputationGraph} from disk and runs it on a single
 * image, logging the network output.
 *
 * <p>Native libraries are loaded through {@link NativeLibraryLoader#load()} when this class is
 * initialized.
 */
public class TestNet {
    private static final Logger log = LoggerFactory.getLogger(TestNet.class);

    /** Serialized model restored by {@link #classifyImage}, relative to the working directory. */
    private static final String MODEL_PATH = "AdaptedVGG16/AdaptedVGG16-grouped-acc-1.zip";

    /** Training images, relative to {@code user.dir}; used by {@link #main} only to collect labels. */
    private static final String TRAINING_DIR = "data/training-grouped-augmented";

    /** Image classified by {@link #main}, relative to the working directory. */
    private static final String VALIDATION_IMAGE = "data/validation/id-fr-front/dimage-1.png";

    /** Input geometry used both for the image loader and for the record reader. */
    private static final int IMAGE_HEIGHT = 224;
    private static final int IMAGE_WIDTH = 224;
    private static final int IMAGE_CHANNELS = 3;

    /**
     * Classifies a single image with the model stored at {@value #MODEL_PATH} and logs the result.
     *
     * <p>The model is restored from disk on every call. The image is loaded as a
     * {@value #IMAGE_HEIGHT}x{@value #IMAGE_WIDTH}, {@value #IMAGE_CHANNELS}-channel matrix and
     * rescaled to [-1, 1] by an {@link ImagePreProcessingScaler} before it is fed to the network.
     * The raw output is logged. The index of the largest output value is then logged too, together
     * with the matching entry of {@code list} when one exists.
     *
     * @param file image file to classify
     * @param list class labels, presumably in the same order as the network's output units (e.g.
     *     the labels collected by an {@link ImageRecordReader}); TODO confirm. May be {@code null},
     *     in which case only the index is logged.
     * @throws RuntimeException if the model or the image cannot be loaded
     */
    public static void classifyImage(File file, List<String> list) {
        ComputationGraph model = getComputationGraph(new File(MODEL_PATH));
        NativeImageLoader loader = new NativeImageLoader(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS);
        ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(-1.0d, 1.0d);

        INDArray input;
        try {
            input = loader.asMatrix(file);
        } catch (IOException e) {
            throw new RuntimeException("Impossible to load image " + file, e);
        }
        scaler.transform(input); // rescales the pixel values in place

        INDArray output = model.outputSingle(new INDArray[] {input});
        log.info("Result: {}.", output);
        logPrediction(output, list);
    }

    /** Logs the index of the largest value in {@code output} and, when available, its label. */
    private static void logPrediction(INDArray output, List<String> labels) {
        int best = argMax(output);
        if (best < 0) {
            return; // empty output: nothing to report
        }
        if (labels != null && best < labels.size()) {
            log.info("Predicted class: {} (index {}).", labels.get(best), best);
        } else {
            log.info("Predicted class index: {}.", best);
        }
    }

    /**
     * Returns the linear index of the largest value in {@code array}, or -1 if it is empty.
     *
     * <p>NOTE(review): assumes a single-row output (one image per call). Confirm before using it
     * on batched output.
     */
    private static int argMax(INDArray array) {
        long length = array.length();
        int best = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < length; i++) {
            double value = array.getDouble(i);
            if (best < 0 || value > bestValue) {
                best = i;
                bestValue = value;
            }
        }
        return best;
    }

    /**
     * Builds the label list from the training directory, then classifies {@value #VALIDATION_IMAGE}.
     *
     * <p>The record reader is initialized only so that its label list can be passed to
     * {@link #classifyImage}. It is closed by try-with-resources even if classification fails. I/O
     * errors while reading the data are logged, not rethrown.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        FileSplit fileSplit = new FileSplit(new File(System.getProperty("user.dir"), TRAINING_DIR),
                BaseImageLoader.ALLOWED_FORMATS);
        ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
        // Fixed seed keeps the sampled split reproducible. The int arguments are presumably
        // (maxPaths, maxLabels, minPathsPerLabel, maxPathsPerLabel); verify against the DataVec
        // version in use.
        BalancedPathFilter pathFilter = new BalancedPathFilter(new Random(123L),
                BaseImageLoader.ALLOWED_FORMATS, labelGenerator, 0, 0, 100, 0, new String[0]);
        InputSplit inputSplit = fileSplit.sample(pathFilter, new double[0])[0];

        try (ImageRecordReader recordReader =
                new ImageRecordReader(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, labelGenerator)) {
            recordReader.initialize(inputSplit, (ImageTransform) null);
            classifyImage(new File(VALIDATION_IMAGE), recordReader.getLabels());
        } catch (IOException e) {
            log.error("Impossible to load data", e);
        }
    }

    /**
     * Restores a {@link ComputationGraph} previously saved with {@link ModelSerializer}.
     *
     * @param file serialized model file
     * @return the restored graph
     * @throws RuntimeException wrapping the {@link IOException} if the model cannot be read
     */
    public static ComputationGraph getComputationGraph(File file) {
        try {
            return ModelSerializer.restoreComputationGraph(file);
        } catch (IOException e) {
            throw new RuntimeException("Impossible to load model from disk.", e);
        }
    }

    // Kept after the logger declaration so static initialization order is unchanged.
    static {
        NativeLibraryLoader.load();
    }
}
