package org.genericsystem.cv;

import java.util.ArrayList;
import java.util.Arrays;
import org.genericsystem.cv.utils.NativeLibraryLoader;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.MatOfFloat;
import org.opencv.core.MatOfInt;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;

/* loaded from: input_file:org/genericsystem/cv/Mlp.class */
/**
 * Small wrapper around OpenCV's {@link ANN_MLP} multi-layer perceptron with one hidden layer.
 *
 * <p>Usage:
 * <ul>
 *   <li>Buffer samples with {@link #addData(float[], float[])}.</li>
 *   <li>Fit the network on every buffered sample with {@link #train()}. This also clears the
 *       buffer.</li>
 *   <li>Run a single input vector through the network with {@link #predict(float[])}.</li>
 * </ul>
 *
 * <p>No synchronization is performed. The sample buffers and the shared {@link #result} matrix
 * are plain mutable fields.
 */
public class Mlp {
    /** Maximum number of samples that can be buffered between two calls to {@link #train()}. */
    static final int MAX_DATA = 1000;

    /** Number of neurons in the single hidden layer. */
    private static final int HIDDEN_LAYER_SIZE = 8;

    // Numeric values of OpenCV constants. They are kept as literals so that this class only
    // references the OpenCV types it already imports.

    /** Value of {@code ANN_MLP.SIGMOID_SYM}: symmetric sigmoid activation. */
    private static final int SIGMOID_SYM = 1;

    /** Value of {@code TermCriteria.MAX_ITER + TermCriteria.EPS}: stop on iteration count or epsilon. */
    private static final int MAX_ITER_OR_EPS = 3;

    /** Value of {@code Ml.ROW_SAMPLE}: each row of the sample matrix is one sample. */
    private static final int ROW_SAMPLE = 0;

    private static final int MAX_ITERATIONS = 100000;

    /** The float 1e-5 widened to double (9.999999747378752E-6), i.e. the original literal. */
    private static final double EPSILON = 1.0e-5f;

    ANN_MLP mlp = ANN_MLP.create();

    /** Width of an input vector (size of the input layer). */
    int input;

    /** Width of a label / output vector (size of the output layer). */
    int output;

    /** Buffered input vectors awaiting {@link #train()}. */
    ArrayList<float[]> train;

    /** Buffered label vectors, index-aligned with {@link #train}. */
    ArrayList<float[]> label;

    /** Output matrix reused by every call to {@link #predict(float[])}. */
    MatOfFloat result;

    /**
     * Trains a tiny 2-input / 2-output network and prints its predictions for the four binary
     * input pairs.
     */
    public static void main(String[] strArr) {
        Mlp mlp = new Mlp(2, 2);
        // NOTE(review): these targets do not match the printed captions.
        // Reading the inputs in the order 00, 11, 01, 10:
        //   - the first output is 1,0,1,1, which is NAND, not XOR;
        //   - the second output is 0,1,0,0, which is AND, not OR.
        // Confirm which targets were intended.
        mlp.addData(new float[]{0.0f, 0.0f}, new float[]{1.0f, 0.0f});
        mlp.addData(new float[]{1.0f, 1.0f}, new float[]{0.0f, 1.0f});
        mlp.addData(new float[]{0.0f, 1.0f}, new float[]{1.0f, 0.0f});
        mlp.addData(new float[]{1.0f, 0.0f}, new float[]{1.0f, 0.0f});
        mlp.train();
        System.out.println("0 xor 0, 0 or 0 = " + Arrays.toString(mlp.predict(new float[]{0.0f, 0.0f})));
        System.out.println("1 xor 1, 1 or 1 = " + Arrays.toString(mlp.predict(new float[]{1.0f, 1.0f})));
        System.out.println("0 xor 1, 0 or 1 = " + Arrays.toString(mlp.predict(new float[]{0.0f, 1.0f})));
        System.out.println("1 xor 0, 1 or 0 = " + Arrays.toString(mlp.predict(new float[]{1.0f, 0.0f})));
    }

    /**
     * Creates a network with layer sizes {@code [i, 8, i2]}, symmetric-sigmoid activation, and
     * training that stops after 100000 iterations or when the change drops below ~1e-5.
     *
     * @param i  number of input values per sample; must be positive
     * @param i2 number of output values per sample; must be positive
     * @throws IllegalArgumentException if either size is not positive
     */
    public Mlp(int i, int i2) {
        if (i <= 0 || i2 <= 0) {
            throw new IllegalArgumentException(
                    "layer sizes must be positive, got input=" + i + ", output=" + i2);
        }
        this.input = i;
        this.output = i2;
        this.mlp.setLayerSizes(new MatOfInt(new int[]{i, HIDDEN_LAYER_SIZE, i2}));
        this.mlp.setActivationFunction(SIGMOID_SYM);
        this.mlp.setTermCriteria(new TermCriteria(MAX_ITER_OR_EPS, MAX_ITERATIONS, EPSILON));
        this.result = new MatOfFloat();
        this.train = new ArrayList<>();
        this.label = new ArrayList<>();
    }

    /**
     * Buffers one training sample.
     *
     * <p>The sample is silently ignored if any of the following holds:
     * <ul>
     *   <li>the input width is not {@link #input};</li>
     *   <li>the label width is not {@link #output};</li>
     *   <li>{@link #MAX_DATA} samples are already buffered.</li>
     * </ul>
     *
     * <p>Both arrays are copied, so later changes by the caller do not affect training.
     *
     * @param fArr  input vector
     * @param fArr2 expected output vector
     */
    void addData(float[] fArr, float[] fArr2) {
        if (fArr.length == this.input
                && fArr2.length == this.output
                && this.train.size() < MAX_DATA) {
            this.train.add(fArr.clone());
            this.label.add(fArr2.clone());
        }
    }

    /** Returns the number of samples currently buffered for training. */
    int getCount() {
        return this.train.size();
    }

    /**
     * Trains the network on all buffered samples, then clears the buffer.
     *
     * <p>If OpenCV throws during training, the buffer is left intact. The native matrices are
     * always released.
     *
     * @throws IllegalStateException if no samples have been buffered
     */
    void train() {
        int count = this.train.size();
        if (count == 0) {
            throw new IllegalStateException("no training data: call addData() before train()");
        }
        Mat samples = new Mat(count, this.input, CvType.CV_32FC1);
        Mat responses = new Mat(count, this.output, CvType.CV_32FC1);
        try {
            // Row-major bulk copies: one sample (or label) per matrix row.
            samples.put(0, 0, flatten(this.train.toArray(new float[0][])));
            responses.put(0, 0, flatten(this.label.toArray(new float[0][])));
            this.mlp.train(samples, ROW_SAMPLE, responses);
        } finally {
            samples.release();
            responses.release();
        }
        this.train.clear();
        this.label.clear();
    }

    /**
     * Runs one input vector through the trained network.
     *
     * @param fArr input vector of width {@link #input}
     * @return the network output, copied out of {@link #result}
     * @throws IllegalStateException if {@code fArr} has the wrong width (this exception type is
     *     kept for compatibility with existing callers)
     */
    float[] predict(float[] fArr) {
        if (fArr.length != this.input) {
            throw new IllegalStateException(
                    "expected " + this.input + " input values but got " + fArr.length);
        }
        Mat sample = new Mat(1, this.input, CvType.CV_32FC1);
        try {
            sample.put(0, 0, fArr);
            // flags = 0: default prediction behaviour.
            this.mlp.predict(sample, this.result, 0);
        } finally {
            sample.release();
        }
        return getResult();
    }

    /** Returns a copy of the output produced by the most recent {@link #predict(float[])}. */
    float[] getResult() {
        return this.result.toArray();
    }

    /**
     * Concatenates the rows of a rectangular 2-D array into one row-major array.
     *
     * @param fArr rows to flatten; every row must have the same length as the first
     * @return the flattened values, or an empty array if there are no rows
     * @throws IllegalArgumentException if the rows have differing lengths
     */
    float[] flatten(float[][] fArr) {
        if (fArr.length == 0) {
            return new float[0];
        }
        int width = fArr[0].length;
        float[] flat = new float[fArr.length * width];
        for (int row = 0; row < fArr.length; row++) {
            if (fArr[row].length != width) {
                throw new IllegalArgumentException(
                        "row " + row + " has " + fArr[row].length + " values, expected " + width);
            }
            System.arraycopy(fArr[row], 0, flat, row * width, width);
        }
        return flat;
    }

    // Presumably loads the OpenCV native library before any ANN_MLP/Mat is created.
    // NOTE(review): confirm against NativeLibraryLoader.
    static {
        NativeLibraryLoader.load();
    }
}
