/*
 * Decompiled with CFR 0.152.
 */
package org.genericsystem.reinforcer;

import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Date;
import org.apache.spark.SparkConf;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.LinearSVC;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains and applies a multi-class text classifier built from one-vs-rest linear SVMs
 * on TF-IDF features, using Spark ML.
 *
 * <p>Input is plain text; the label of every line is derived from the name of the
 * directory that directly contains the file the line was read from (see
 * {@link #loadData(SparkSession, String)}).
 */
public class SVMClassifier {
    private static final Logger logger = LoggerFactory.getLogger(SVMClassifier.class);

    /**
     * Starts a local Spark session, trains a model and stops the session.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setMaster("local[*]");
        SparkSession spark = SparkSession.builder().config(sparkConf).appName("SVMClassifier").getOrCreate();
        try {
            new SVMClassifier().trainModel(spark);
        } finally {
            // Always release the session, even when training or saving fails.
            spark.stop();
        }
    }

    /**
     * Reads the text file(s) matching {@code path} and attaches a {@code label} column.
     *
     * <p>The label is extracted from the input file name with the pattern
     * {@code .*}{@code /([^/]*)/[^/]*}, i.e. the last directory component before the file name.
     * The returned dataset is cached.
     *
     * @param spark the session used to read the data
     * @param path  path or glob of the text file(s) to read
     * @return a cached dataset with the text column produced by the reader plus {@code label}
     */
    public Dataset<Row> loadData(SparkSession spark, String path) {
        // First store the full input file name in "label", then reduce it to the parent directory name.
        Dataset<Row> withFileName = spark.read().text(path).withColumn("label", functions.input_file_name());
        Column parentDirectory = functions.regexp_extract(withFileName.col("label"), ".*/([^/]*)/[^/]*", 1);
        return withFileName.withColumn("label", parentDirectory).cache();
    }

    /**
     * Trains the classifier on {@code pieces/text/*}, logs the error on a held-out split
     * and saves the resulting model.
     *
     * <p>The data is split 70/30 into a training and a validation part. A pipeline
     * (label indexing, tokenizing, hashing TF, IDF, one-vs-rest linear SVC, index-to-label
     * conversion) is tuned with 3-fold cross-validation over a grid of feature counts,
     * iteration counts and regularization parameters, using accuracy as the metric.
     *
     * @param spark the session used to load the data
     * @throws RuntimeException if the trained model cannot be saved
     */
    public void trainModel(SparkSession spark) {
        Dataset<Row> data = this.loadData(spark, "pieces/text/*");
        Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        Dataset<Row> trainData = splits[0];
        Dataset<Row> validData = splits[1];

        // Feature extraction stages: label -> labelIndex, value -> words -> rawFeatures -> features.
        StringIndexer indexer = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").setHandleInvalid("keep");
        Tokenizer tokenizer = new Tokenizer().setInputCol("value").setOutputCol("words");
        HashingTF hashingTF = new HashingTF().setInputCol(tokenizer.getOutputCol()).setOutputCol("rawFeatures").setNumFeatures(20);
        IDF idf = new IDF().setInputCol(hashingTF.getOutputCol()).setOutputCol("features");

        // Binary linear SVC wrapped in one-vs-rest to handle more than two labels.
        LinearSVC lsvc = new LinearSVC().setMaxIter(100).setRegParam(0.1);
        OneVsRest ovr = new OneVsRest().setClassifier(lsvc).setFeaturesCol(idf.getOutputCol()).setLabelCol(indexer.getOutputCol());

        // The reverse mapping needs the label array up front, so the indexer is fitted
        // once on the training split here, outside of the pipeline.
        IndexToString converter = new IndexToString().setLabels(indexer.fit(trainData).labels()).setInputCol("prediction").setOutputCol("origPrediction");

        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{indexer, tokenizer, hashingTF, idf, ovr, converter});

        // The values set on the stages above are overridden by the grid during cross-validation.
        ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(hashingTF.numFeatures(), new int[]{20, 50, 200, 1000})
                .addGrid(lsvc.maxIter(), new int[]{20, 50, 100, 200})
                .addGrid(lsvc.regParam(), new double[]{0.05, 0.1})
                .build();
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy").setLabelCol(indexer.getOutputCol());
        CrossValidator cv = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid).setNumFolds(3);

        CrossValidatorModel cvModel = cv.fit(trainData);

        // Evaluate on the split that was not used for fitting.
        Dataset<Row> predictions = cvModel.transform(validData);
        double accuracy = evaluator.evaluate(predictions);
        logger.info("Test error = {}.", 1.0 - accuracy);

        this.saveModel(cvModel);
    }

    /**
     * Applies {@code model} to the data at {@code file} and prints the
     * {@code labelIndex}, {@code label} and {@code prediction} columns.
     *
     * @param spark the session used to load the data
     * @param model a trained model
     * @param file  path or glob of the text file(s) to classify
     */
    public void testModel(SparkSession spark, CrossValidatorModel model, String file) {
        Dataset<Row> data = this.loadData(spark, file);
        Dataset<Row> prediction = model.transform(data);
        prediction.select("labelIndex", "label", "prediction").show();
    }

    /** Saves the model under {@code SVMModel-<yyyyMMddHHmmss>}, using the current time. */
    private void saveModel(CrossValidatorModel cvModel) {
        SimpleDateFormat format = new SimpleDateFormat("yyyyMMddHHmmss");
        try {
            cvModel.save("SVMModel-" + format.format(new Date()));
        } catch (IOException e) {
            throw new RuntimeException("Exception while saving trained model", e);
        }
    }
}

