package org.genericsystem.reinforcer;

import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Date;
import org.apache.spark.SparkConf;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LinearSVC;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/genericsystem/reinforcer/SVMClassifier.class */
/**
 * Trains and evaluates a one-vs-rest linear SVM text classifier with Spark ML.
 *
 * <p>Each line of each input text file is one sample; the label of a sample is the
 * name of the directory that directly contains the file it was read from.
 */
public class SVMClassifier {
    private static final Logger logger = LoggerFactory.getLogger(SVMClassifier.class);

    /**
     * Starts a local Spark session, trains the model and stops the session.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .config(new SparkConf().setMaster("local[*]"))
                .appName("SVMClassifier")
                .getOrCreate();
        try {
            new SVMClassifier().trainModel(spark);
        } finally {
            // Stop the session even when training fails.
            spark.stop();
        }
    }

    /**
     * Reads text files as a dataset with a {@code value} column (the line of text) and a
     * {@code label} column (the name of the file's parent directory).
     *
     * @param sparkSession session used to read the files
     * @param str path or glob of the text files to load
     * @return the cached, labelled dataset
     */
    public Dataset<Row> loadData(SparkSession sparkSession, String str) {
        Dataset<Row> withFileName = sparkSession.read().text(str)
                .withColumn("label", functions.input_file_name());
        // Keep only the last directory component of the file path as the label.
        return withFileName
                .withColumn("label", functions.regexp_extract(withFileName.col("label"), ".*/([^/]*)/[^/]*", 1))
                .cache();
    }

    /**
     * Trains the classifier on the files matching {@code pieces/text/*}.
     *
     * <p>The data is split 70/30 into training and test sets. A pipeline of label indexing,
     * tokenizing, TF-IDF and a one-vs-rest {@link LinearSVC} is tuned by 3-fold cross-validation
     * over a parameter grid, the test error of the selected model is logged, and the model is
     * saved under {@code SVMModel-<yyyyMMddHHmmss>}.
     *
     * @param sparkSession session used to load the training data
     * @throws RuntimeException if the trained model cannot be saved
     */
    public void trainModel(SparkSession sparkSession) {
        Dataset<Row>[] splits = loadData(sparkSession, "pieces/text/*").randomSplit(new double[]{0.7d, 0.3d});
        Dataset<Row> training = splits[0];
        Dataset<Row> test = splits[1];

        StringIndexer labelIndexer = new StringIndexer()
                .setInputCol("label")
                .setOutputCol("labelIndex")
                .setHandleInvalid("keep");
        Tokenizer tokenizer = new Tokenizer()
                .setInputCol("value")
                .setOutputCol("words");
        HashingTF hashingTF = new HashingTF()
                .setInputCol(tokenizer.getOutputCol())
                .setOutputCol("rawFeatures")
                .setNumFeatures(20);
        IDF idf = new IDF()
                .setInputCol(hashingTF.getOutputCol())
                .setOutputCol("features");
        LinearSVC svc = new LinearSVC()
                .setMaxIter(100)
                .setRegParam(0.1d);
        OneVsRest oneVsRest = new OneVsRest()
                .setClassifier(svc)
                .setFeaturesCol(idf.getOutputCol())
                .setLabelCol(labelIndexer.getOutputCol());
        // The indexer is fitted once on the training split only to obtain the label names
        // used to map numeric predictions back to the original labels.
        IndexToString labelConverter = new IndexToString()
                .setLabels(labelIndexer.fit(training).labels())
                .setInputCol("prediction")
                .setOutputCol("origPrediction");

        Pipeline pipeline = new Pipeline().setStages(
                new PipelineStage[]{labelIndexer, tokenizer, hashingTF, idf, oneVsRest, labelConverter});

        ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(hashingTF.numFeatures(), new int[]{20, 50, 200, 1000})
                .addGrid(svc.maxIter(), new int[]{20, 50, 100, 200})
                .addGrid(svc.regParam(), new double[]{0.05d, 0.1d})
                .build();

        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setMetricName("accuracy")
                .setLabelCol(labelIndexer.getOutputCol());

        CrossValidatorModel model = new CrossValidator()
                .setEstimator(pipeline)
                .setEvaluator(evaluator)
                .setEstimatorParamMaps(paramGrid)
                .setNumFolds(3)
                .fit(training);

        double accuracy = evaluator.evaluate(model.transform(test));
        logger.info("Test error = {}.", 1.0d - accuracy);

        try {
            model.save("SVMModel-" + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date()));
        } catch (IOException e) {
            throw new RuntimeException("Exception while saving trained model", e);
        }
    }

    /**
     * Applies a trained model to the files at the given path and prints the
     * {@code labelIndex}, {@code label} and {@code prediction} columns.
     *
     * @param sparkSession session used to load the data
     * @param crossValidatorModel trained model to apply
     * @param str path or glob of the text files to classify
     */
    public void testModel(SparkSession sparkSession, CrossValidatorModel crossValidatorModel, String str) {
        Dataset<Row> predictions = crossValidatorModel.transform(loadData(sparkSession, str));
        predictions.select("labelIndex", "label", "prediction").show();
    }
}
