package org.genericsystem.reinforcer.tools;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Cosine similarity between two strings, each treated as a bag of tokens.
 *
 * <p>Each string is split into tokens with one of the {@link PATTERN} delimiters and turned into a
 * term-frequency vector. The result is the cosine of the angle between the two vectors: 1.0 for
 * identical token distributions, 0.0 when the strings share no token.
 */
public class CosineSimilarity {

    /** Tokenization strategies, each backed by a regular expression used as a split delimiter. */
    public enum PATTERN {
        /** Splits at every position except the start, yielding one token per character. */
        SINGLE_CHAR(Pattern.compile("(?!^)")),
        /**
         * Splits on runs of non-word characters, yielding word tokens. Unicode-aware so that
         * accented letters (e.g. "é") are part of words rather than separators.
         */
        WORDS(Pattern.compile("\\W+", Pattern.UNICODE_CHARACTER_CLASS)),
        /** Splits on runs of whitespace, yielding whitespace-separated tokens. */
        SPACE(Pattern.compile("\\s+", Pattern.UNICODE_CHARACTER_CLASS));

        private final Pattern pattern;

        PATTERN(Pattern pattern) {
            this.pattern = pattern;
        }

        /**
         * Returns the delimiter regex used to split input strings into tokens.
         *
         * @return the compiled delimiter pattern
         */
        public Pattern getPattern() {
            return this.pattern;
        }
    }

    /** Small demo: prints the character-level similarity of two word pairs. */
    public static void main(String[] strArr) {
        System.out.println(cosineSimilarity("bob", "rob", PATTERN.SINGLE_CHAR));
        System.out.println(cosineSimilarity("hello", "molehill", PATTERN.SINGLE_CHAR));
    }

    /**
     * Computes the cosine similarity of the token-frequency vectors of two strings.
     *
     * @param str the first string, must not be null
     * @param str2 the second string, must not be null
     * @param pattern the tokenization strategy; only consulted when the strings differ and are
     *     both non-empty
     * @return 1.0 if the strings are equal; 0.0 if either is empty or they share no token;
     *     otherwise a value in (0, 1]
     * @throws IllegalArgumentException if {@code str} or {@code str2} is null
     * @throws NullPointerException if {@code pattern} is null and tokenization is required
     */
    public static double cosineSimilarity(String str, String str2, PATTERN pattern) {
        if (null == str || null == str2) {
            throw new IllegalArgumentException("Cosine similarity requires two not null strings");
        }
        if (str.equals(str2)) {
            return 1.0d;
        }
        if (str.isEmpty() || str2.isEmpty()) {
            // An empty string is the zero vector: by convention it is dissimilar to everything else.
            return 0.0d;
        }
        Map<String, Long> left = getFrequencyMap(str, pattern);
        Map<String, Long> right = getFrequencyMap(str2, pattern);

        double dotProduct = 0.0d;
        for (Map.Entry<String, Long> entry : left.entrySet()) {
            Long otherCount = right.get(entry.getKey());
            if (otherCount != null) {
                dotProduct += entry.getValue().longValue() * otherCount.longValue();
            }
        }
        // No shared token (this also covers a map that is empty after filtering): similarity is 0,
        // and returning early avoids a 0/0 division.
        if (dotProduct == 0.0d) {
            return 0.0d;
        }
        return dotProduct / Math.sqrt(sumOfSquares(left) * sumOfSquares(right));
    }

    /**
     * Splits a trimmed string with the given pattern and counts occurrences of each token.
     *
     * <p>Empty tokens (e.g. the leading "" produced by {@code "!a".split("\\W+")}) are discarded so
     * they never count as a shared term.
     *
     * @param str the string to tokenize, must not be null
     * @param pattern the tokenization strategy, must not be null
     * @return a map from each non-empty token to its number of occurrences (empty if none)
     * @throws NullPointerException if {@code str} or {@code pattern} is null
     */
    public static Map<String, Long> getFrequencyMap(String str, PATTERN pattern) {
        Objects.requireNonNull(str, "str");
        Objects.requireNonNull(pattern, "pattern");
        return Arrays.stream(pattern.getPattern().split(str.trim()))
                .filter(token -> !token.isEmpty())
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    /** Returns the squared Euclidean norm of a term-frequency vector. */
    private static double sumOfSquares(Map<String, Long> frequencies) {
        double sum = 0.0d;
        for (long count : frequencies.values()) {
            sum += (double) count * count;
        }
        return sum;
    }
}
