/*
 * Decompiled with CFR 0.152.
 */
package org.genericsystem.cv;

import java.util.ArrayList;
import java.util.Arrays;
import org.genericsystem.cv.utils.NativeLibraryLoader;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.MatOfFloat;
import org.opencv.core.MatOfInt;
import org.opencv.core.TermCriteria;
import org.opencv.ml.ANN_MLP;

/**
 * Small wrapper around OpenCV's {@link ANN_MLP} multi-layer perceptron.
 *
 * <p>Usage: buffer samples with {@link #addData(float[], float[])}, fit them in one batch with
 * {@link #train()}, then evaluate single samples with {@link #predict(float[])}. The network has
 * one hidden layer.
 *
 * <p>Not thread-safe: {@link #predict(float[])} writes into the shared {@link #result} matrix.
 */
public class Mlp {
    /** Maximum number of buffered samples; further calls to {@code addData} are ignored. */
    static final int MAX_DATA = 1000;

    /** Hidden-layer width used by the two-argument constructor. */
    private static final int DEFAULT_HIDDEN = 8;
    /** Mirrors {@code ANN_MLP.SIGMOID_SYM}: symmetric sigmoid activation. */
    private static final int SIGMOID_SYM = 1;
    /** Mirrors {@code TermCriteria.COUNT | TermCriteria.EPS}: stop on iteration count or epsilon. */
    private static final int TERM_COUNT_OR_EPS = 1 | 2;
    /** Upper bound on training iterations. */
    private static final int MAX_ITERATIONS = 100000;
    /** Convergence epsilon (kept as a float literal to preserve the original value exactly). */
    private static final float TERM_EPSILON = 1.0E-5f;
    /** Mirrors {@code Ml.ROW_SAMPLE}: each matrix row is one sample. */
    private static final int ROW_SAMPLE = 0;
    /** Flags passed to {@code predict}: 0 requests the default prediction mode. */
    private static final int PREDICT_FLAGS = 0;

    /** The underlying OpenCV model. */
    ANN_MLP mlp;
    /** Number of input features per sample (input-layer size). */
    int input;
    /** Number of output values per sample (output-layer size). */
    int output;
    /** Buffered training inputs, each of length {@link #input}; cleared by {@link #train()}. */
    ArrayList<float[]> train;
    /** Buffered training targets, parallel to {@link #train}; cleared by {@link #train()}. */
    ArrayList<float[]> label;
    /** Reused output matrix filled by {@link #predict(float[])}. */
    MatOfFloat result;

    /**
     * Small demo: trains on four two-bit samples and prints the network output for each.
     *
     * <p>NOTE(review): the labels below encode {NAND, AND} of the two inputs, not the {XOR, OR}
     * suggested by the printed captions. Confirm which was intended.
     */
    public static void main(String[] args) {
        Mlp mlp = new Mlp(2, 2);
        mlp.addData(new float[]{0.0f, 0.0f}, new float[]{1.0f, 0.0f});
        mlp.addData(new float[]{1.0f, 1.0f}, new float[]{0.0f, 1.0f});
        mlp.addData(new float[]{0.0f, 1.0f}, new float[]{1.0f, 0.0f});
        mlp.addData(new float[]{1.0f, 0.0f}, new float[]{1.0f, 0.0f});
        mlp.train();
        System.out.println("0 xor 0, 0 or 0 = " + Arrays.toString(mlp.predict(new float[]{0.0f, 0.0f})));
        System.out.println("1 xor 1, 1 or 1 = " + Arrays.toString(mlp.predict(new float[]{1.0f, 1.0f})));
        System.out.println("0 xor 1, 0 or 1 = " + Arrays.toString(mlp.predict(new float[]{0.0f, 1.0f})));
        System.out.println("1 xor 0, 1 or 0 = " + Arrays.toString(mlp.predict(new float[]{1.0f, 0.0f})));
    }

    /**
     * Creates a network with layer sizes {@code {input, 8, output}}.
     *
     * @param input number of input features per sample
     * @param output number of output values per sample
     */
    public Mlp(int input, int output) {
        this(input, DEFAULT_HIDDEN, output);
    }

    /**
     * Creates a network with layer sizes {@code {input, hidden, output}}, a symmetric sigmoid
     * activation, and termination after {@value #MAX_ITERATIONS} iterations or once the change
     * falls below the epsilon.
     *
     * @param input number of input features per sample
     * @param hidden number of neurons in the single hidden layer
     * @param output number of output values per sample
     */
    public Mlp(int input, int hidden, int output) {
        this.input = input;
        this.output = output;
        this.mlp = ANN_MLP.create();
        MatOfInt layerSizes = new MatOfInt(new int[]{input, hidden, output});
        try {
            this.mlp.setLayerSizes(layerSizes);
        } finally {
            // The model keeps its own copy of the layer sizes; free the temporary native buffer.
            layerSizes.release();
        }
        this.mlp.setActivationFunction(SIGMOID_SYM);
        this.mlp.setTermCriteria(new TermCriteria(TERM_COUNT_OR_EPS, MAX_ITERATIONS, TERM_EPSILON));
        this.result = new MatOfFloat();
        this.train = new ArrayList<>();
        this.label = new ArrayList<>();
    }

    /**
     * Buffers one training sample.
     *
     * <p>The sample is silently ignored when {@code t.length != input}, when
     * {@code l.length != output}, or when {@value #MAX_DATA} samples are already buffered.
     * Both arrays are copied, so later changes by the caller do not affect the buffer.
     *
     * @param t input feature vector
     * @param l target output vector
     * @throws NullPointerException if {@code t} or {@code l} is null
     */
    void addData(float[] t, float[] l) {
        if (t.length != this.input || l.length != this.output) {
            return;
        }
        if (this.train.size() >= MAX_DATA) {
            return;
        }
        this.train.add(t.clone());
        this.label.add(l.clone());
    }

    /** Returns the number of currently buffered training samples. */
    int getCount() {
        return this.train.size();
    }

    /**
     * Trains the network on all buffered samples in one batch.
     *
     * <p>The buffer is cleared on success. It is kept if the underlying training call throws.
     *
     * @throws IllegalStateException if no samples are buffered
     */
    void train() {
        if (this.train.isEmpty()) {
            throw new IllegalStateException("No training data: call addData() before train()");
        }
        int rows = this.train.size();
        Mat trainData = new Mat(rows, this.input, CvType.CV_32FC1);
        Mat response = new Mat(rows, this.output, CvType.CV_32FC1);
        try {
            // addData guarantees every row has exactly input/output columns, so flatten is safe.
            trainData.put(0, 0, this.flatten(this.train.toArray(new float[0][])));
            response.put(0, 0, this.flatten(this.label.toArray(new float[0][])));
            this.mlp.train(trainData, ROW_SAMPLE, response);
        } finally {
            // Release native memory even when training throws.
            trainData.release();
            response.release();
        }
        this.train.clear();
        this.label.clear();
    }

    /**
     * Runs the trained network on one sample.
     *
     * @param i input feature vector of length {@link #input}
     * @return the network output, one value per output-layer neuron
     * @throws IllegalStateException if {@code i.length != input}
     */
    float[] predict(float[] i) {
        if (i.length != this.input) {
            throw new IllegalStateException(
                    "Expected " + this.input + " input values but got " + i.length);
        }
        Mat test = new Mat(1, this.input, CvType.CV_32FC1);
        try {
            test.put(0, 0, i);
            this.mlp.predict(test, this.result, PREDICT_FLAGS);
        } finally {
            // Without this release every prediction leaked a native matrix.
            test.release();
        }
        return this.getResult();
    }

    /** Returns the output of the most recent {@link #predict(float[])} call as a float array. */
    float[] getResult() {
        return this.result.toArray();
    }

    /**
     * Concatenates the rows of {@code a} in row-major order.
     *
     * <p>The column count is taken from {@code a[0]}. Longer rows are truncated, and a shorter row
     * causes an index exception.
     *
     * @param a rectangular 2-D array
     * @return a new array of length {@code a.length * a[0].length}, or an empty array when
     *     {@code a} has no rows
     */
    float[] flatten(float[][] a) {
        if (a.length == 0) {
            return new float[0];
        }
        int cols = a[0].length;
        float[] flat = new float[a.length * cols];
        for (int r = 0; r < a.length; ++r) {
            System.arraycopy(a[r], 0, flat, r * cols, cols);
        }
        return flat;
    }

    static {
        NativeLibraryLoader.load();
    }
}

