package org.genericsystem.reinforcer;

import java.io.File;
import java.io.IOException;
import java.util.Random;
import org.datavec.api.conf.Configuration;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/genericsystem/reinforcer/NNClassifier.class */
/**
 * Trains and evaluates a feed-forward text classifier over the {@code .txt} files found under
 * {@code <user.dir>/pieces/text}.
 *
 * <p>Each file is turned into a feature record by {@code VecRecordReader} using a pre-trained
 * Word2Vec model ({@link #frModel}); the file's parent directory name is appended as its label
 * ({@link ParentPathLabelGenerator} plus {@link FileRecordReader#APPEND_LABEL}). Files are sampled
 * 70/15/15 into training, validation and test splits.
 */
public class NNClassifier {
    /** Seed shared by the network weight initialisation and the file sampling RNG. */
    private static final int SEED = 123;
    /**
     * Number of feature columns per record. It is used both as the network input width and as the
     * label column index, so the reader is assumed to emit exactly this many features before the label.
     */
    private static final int VECTOR_SIZE = 500;
    /** Width of each hidden dense layer. */
    private static final int HIDDEN_SIZE = 1024;
    /** Mini-batch size used for every split. */
    private static final int BATCH_SIZE = 1;
    /** Upper bound on training epochs (plain loop and early-stopping alike). */
    private static final int MAX_EPOCHS = 100;
    /** Early stopping gives up after this many epochs without validation-score improvement. */
    private static final int MAX_EPOCHS_WITHOUT_IMPROVEMENT = 20;
    /** Labels with fewer files than this are excluded by the balanced sampler. */
    private static final int MIN_PATHS_PER_LABEL = 20;
    /**
     * When {@code true}, train with the early-stopping trainer (validated on the validation split);
     * when {@code false}, run {@link #MAX_EPOCHS} plain epochs.
     */
    private static final boolean USE_EARLY_STOPPING = false;

    protected static final Logger log = LoggerFactory.getLogger(NNClassifier.class);
    private static final String[] allowedExtensions = {"txt"};
    public static final Random randNumGen = new Random(SEED);
    private static final File frModel = new File("frWiki_no_phrase_no_postag_500_cbow_cut10.bin");

    /**
     * Loads the corpus and the Word2Vec model, trains the classifier, and logs evaluation
     * statistics on the held-out test split.
     *
     * @param args ignored
     * @throws Exception if the Word2Vec model or the corpus cannot be read
     */
    public static void main(String[] args) throws Exception {
        InputSplit[] splits = splitInput(new File(System.getProperty("user.dir"), "pieces/text"));
        InputSplit trainSplit = splits[0];
        InputSplit validationSplit = splits[1];
        InputSplit testSplit = splits[2];

        Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(frModel, true);
        DefaultTokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());

        Configuration configuration = new Configuration();
        configuration.setBoolean(FileRecordReader.APPEND_LABEL, true);

        // Every split gets its OWN reader: initialize() rebinds a reader to a new split, and an
        // iterator's reset() rewinds whatever split its reader is currently bound to. Sharing one
        // reader would silently redirect the training iterator to the last-initialised split.
        // NOTE(review): each reader derives its label list from its own split; if a small split
        // misses a class, label indices could differ between splits — confirm VecRecordReader's
        // label handling.
        VecRecordReader trainReader = new VecRecordReader(word2Vec, tokenizerFactory);
        // Initialised here only to discover the label set; getDataSetIterator re-initialises it.
        trainReader.initialize(configuration, trainSplit);
        int numLabels = trainReader.getLabels().size();

        MultiLayerNetwork network = buildNetwork(numLabels);
        network.init();
        network.setListeners(new IterationListener[]{new ScoreIterationListener(10)});

        NormalizerStandardize normalizer = new NormalizerStandardize();
        DataSetIterator trainIterator =
                getDataSetIterator(trainReader, configuration, null, trainSplit, BATCH_SIZE, numLabels);
        // Statistics are fitted on the training data only, then applied to every split.
        normalizer.fit(trainIterator);
        trainIterator.setPreProcessor(normalizer);

        if (USE_EARLY_STOPPING) {
            DataSetIterator validationIterator = getDataSetIterator(
                    new VecRecordReader(word2Vec, tokenizerFactory), configuration, normalizer,
                    validationSplit, BATCH_SIZE, numLabels);
            EarlyStoppingTrainer trainer =
                    buildEarlyStoppingTrainer(network, trainIterator, validationIterator);
            EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
            log.info("Early stopping finished: {} after {} epochs",
                    result.getTerminationReason(), result.getTotalEpochs());
            network = result.getBestModel();
        } else {
            log.info("Training without early stopping");
            for (int epoch = 0; epoch < MAX_EPOCHS; epoch++) {
                network.fit(trainIterator);
                log.info("Completed epoch {}", epoch);
                trainIterator.reset();
            }
        }

        log.info("Evaluate model....");
        DataSetIterator testIterator = getDataSetIterator(
                new VecRecordReader(word2Vec, tokenizerFactory), configuration, normalizer,
                testSplit, BATCH_SIZE, numLabels);
        log.info(network.evaluate(testIterator).stats(true));
    }

    /**
     * Samples the {@code .txt} files under {@code rootDir} into training/validation/test splits
     * with 70/15/15 weights, balanced per parent-directory label.
     *
     * @param rootDir directory whose sub-directories name the labels
     * @return three splits: training, validation, test (in that order)
     */
    private static InputSplit[] splitInput(File rootDir) {
        // The FileSplit is created before the filter so the shared RNG is consumed in the same
        // order as before, keeping the sampling reproducible for a fixed seed.
        FileSplit allFiles = new FileSplit(rootDir, allowedExtensions, randNumGen);
        BalancedPathFilter balancedFilter = new BalancedPathFilter(randNumGen, allowedExtensions,
                new ParentPathLabelGenerator(), 0, 0, MIN_PATHS_PER_LABEL, 0, new String[0]);
        return allFiles.sample(balancedFilter, new double[]{70.0d, 15.0d, 15.0d});
    }

    /**
     * Builds the (uninitialised) network: three {@link #HIDDEN_SIZE}-wide tanh dense layers
     * followed by a softmax output layer trained with negative log-likelihood.
     *
     * @param numLabels number of output classes
     * @return the network; callers must still call {@link MultiLayerNetwork#init()}
     */
    private static MultiLayerNetwork buildNetwork(int numLabels) {
        return new MultiLayerNetwork(new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .weightInit(WeightInit.XAVIER)
                .iterations(1)
                .activation(Activation.TANH)
                .learningRate(0.001d)
                .updater(new Nesterovs(0.9d))
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .regularization(true)
                .l2(1.0E-4d)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(VECTOR_SIZE).nOut(HIDDEN_SIZE).build())
                .layer(1, new DenseLayer.Builder().nIn(HIDDEN_SIZE).nOut(HIDDEN_SIZE).build())
                .layer(2, new DenseLayer.Builder().nIn(HIDDEN_SIZE).nOut(HIDDEN_SIZE).build())
                .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(HIDDEN_SIZE).nOut(numLabels).activation(Activation.SOFTMAX).build())
                .pretrain(false)
                .backprop(true)
                .build());
    }

    /**
     * Builds an early-stopping trainer that scores the network on {@code validationIterator}
     * after every epoch and stops after {@link #MAX_EPOCHS} epochs or after
     * {@link #MAX_EPOCHS_WITHOUT_IMPROVEMENT} epochs without improvement, whichever comes first.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // raw Builder, as in the original configuration
    private static EarlyStoppingTrainer buildEarlyStoppingTrainer(MultiLayerNetwork network,
            DataSetIterator trainIterator, DataSetIterator validationIterator) {
        EarlyStoppingConfiguration config = new EarlyStoppingConfiguration.Builder()
                // Both conditions go in ONE call: each epochTerminationConditions(...) call replaces
                // the previously configured list rather than appending to it.
                .epochTerminationConditions(new EpochTerminationCondition[]{
                        new MaxEpochsTerminationCondition(MAX_EPOCHS),
                        new ScoreImprovementEpochTerminationCondition(MAX_EPOCHS_WITHOUT_IMPROVEMENT)})
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationIterator, false))
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(config, network, trainIterator);
        trainer.setListener(new EpochScoreLogger());
        return trainer;
    }

    /** Early-stopping listener that logs the validation score after each epoch. */
    private static final class EpochScoreLogger implements EarlyStoppingListener<MultiLayerNetwork> {
        @Override
        public void onStart(EarlyStoppingConfiguration<MultiLayerNetwork> config, MultiLayerNetwork net) {
        }

        @Override
        public void onEpoch(int epoch, double score,
                EarlyStoppingConfiguration<MultiLayerNetwork> config, MultiLayerNetwork net) {
            log.info("Epoch {}, score {}.", epoch, score);
        }

        @Override
        public void onCompletion(EarlyStoppingResult<MultiLayerNetwork> result) {
        }
    }

    /**
     * (Re)initialises {@code fileRecordReader} on {@code inputSplit} and wraps it in a dataset
     * iterator whose label sits in column {@link #VECTOR_SIZE}.
     *
     * @param fileRecordReader reader to bind to {@code inputSplit}; it must not back any other
     *     live iterator, because initialisation rebinds it
     * @param configuration reader configuration
     * @param dataNormalization optional pre-processor applied to every batch; may be {@code null}
     * @param inputSplit split to read
     * @param batchSize mini-batch size
     * @param numLabels number of label classes
     * @return the dataset iterator
     * @throws IllegalStateException if the reader cannot be initialised (the interrupt flag is
     *     restored when the cause is an interruption)
     */
    private static DataSetIterator getDataSetIterator(FileRecordReader fileRecordReader,
            Configuration configuration, DataNormalization dataNormalization, InputSplit inputSplit,
            int batchSize, int numLabels) {
        try {
            fileRecordReader.initialize(configuration, inputSplit);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while initializing recordReader", e);
        } catch (IOException e) {
            // Continuing with an uninitialised reader would only fail later with a less clear error.
            throw new IllegalStateException("Impossible to initialize recordReader", e);
        }
        RecordReaderDataSetIterator iterator =
                new RecordReaderDataSetIterator(fileRecordReader, batchSize, VECTOR_SIZE, numLabels);
        if (dataNormalization != null) {
            iterator.setPreProcessor(dataNormalization);
        }
        return iterator;
    }
}
