package smile.validation;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.ToDoubleFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.classification.SoftClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.validation.metric.AUC;
import smile.validation.metric.Accuracy;
import smile.validation.metric.ConfusionMatrix;
import smile.validation.metric.CrossEntropy;
import smile.validation.metric.Error;
import smile.validation.metric.FScore;
import smile.validation.metric.LogLoss;
import smile.validation.metric.MatthewsCorrelation;
import smile.validation.metric.Precision;
import smile.validation.metric.Sensitivity;
import smile.validation.metric.Specificity;

/* loaded from: classes6.dex */
/**
 * The result of validating a classification model on a test set: the model,
 * the ground-truth and predicted labels, the optional posterior probabilities,
 * the confusion matrix, and the derived classification metrics.
 *
 * @param <M> the type of the validated model.
 */
public class ClassificationValidation<M> implements Serializable {
    private static final long serialVersionUID = 2L;
    /** The confusion matrix of {@link #truth} versus {@link #prediction}. */
    public final ConfusionMatrix confusion;
    /** The timings and quality metrics of this validation. */
    public final ClassificationMetrics metrics;
    /** The validated model. */
    public final M model;
    /** The posterior probabilities of the test samples, or {@code null} if not available. */
    public final double[][] posteriori;
    /** The predicted labels of the test samples. */
    public final int[] prediction;
    /** The ground-truth labels of the test samples. */
    public final int[] truth;

    /**
     * Constructor for a validation without posterior probabilities.
     *
     * @param model the validated model.
     * @param truth the ground-truth labels.
     * @param prediction the predicted labels.
     * @param fitTime the training time (milliseconds when produced by the {@code of} factories).
     * @param scoreTime the prediction time (milliseconds when produced by the {@code of} factories).
     */
    public ClassificationValidation(M model, int[] truth, int[] prediction, double fitTime, double scoreTime) {
        this(model, truth, prediction, null, fitTime, scoreTime);
    }

    /**
     * Constructor. Computes the confusion matrix and the metrics.
     *
     * <p>If {@code truth} contains exactly two distinct labels, the binary
     * metrics (sensitivity, specificity, precision, F1, Matthews correlation)
     * are computed. When posteriors are given, AUC and log loss are also
     * computed from column 1 of each posterior row. Otherwise only error and
     * accuracy are computed, plus cross entropy when posteriors are given.
     *
     * @param model the validated model.
     * @param truth the ground-truth labels.
     * @param prediction the predicted labels.
     * @param posteriori the posterior probabilities, or {@code null}.
     * @param fitTime the training time (milliseconds when produced by the {@code of} factories).
     * @param scoreTime the prediction time (milliseconds when produced by the {@code of} factories).
     */
    public ClassificationValidation(M model, int[] truth, int[] prediction, double[][] posteriori, double fitTime, double scoreTime) {
        this.model = model;
        this.truth = truth;
        this.prediction = prediction;
        this.posteriori = posteriori;
        this.confusion = ConfusionMatrix.of(truth, prediction);

        if (MathEx.unique(truth).length == 2) {
            if (posteriori == null) {
                this.metrics = new ClassificationMetrics(fitTime, scoreTime, truth.length,
                        Error.of(truth, prediction), Accuracy.of(truth, prediction),
                        Sensitivity.of(truth, prediction), Specificity.of(truth, prediction),
                        Precision.of(truth, prediction), FScore.F1.score(truth, prediction),
                        MatthewsCorrelation.of(truth, prediction));
            } else {
                // Binary case: AUC and log loss are computed from the posterior of label 1.
                double[] probability = Arrays.stream(posteriori).mapToDouble(p -> p[1]).toArray();
                this.metrics = new ClassificationMetrics(fitTime, scoreTime, truth.length,
                        Error.of(truth, prediction), Accuracy.of(truth, prediction),
                        Sensitivity.of(truth, prediction), Specificity.of(truth, prediction),
                        Precision.of(truth, prediction), FScore.F1.score(truth, prediction),
                        MatthewsCorrelation.of(truth, prediction),
                        AUC.of(truth, probability), LogLoss.of(truth, probability));
            }
        } else if (posteriori == null) {
            this.metrics = new ClassificationMetrics(fitTime, scoreTime, truth.length,
                    Error.of(truth, prediction), Accuracy.of(truth, prediction));
        } else {
            this.metrics = new ClassificationMetrics(fitTime, scoreTime, truth.length,
                    Error.of(truth, prediction), Accuracy.of(truth, prediction),
                    CrossEntropy.of(truth, posteriori));
        }
    }

    /**
     * Trains a model on {@code train} and validates it on {@code test}.
     *
     * @param formula the model formula; its response gives the class labels.
     * @param train the training data.
     * @param test the test data.
     * @param trainer the function that trains a model from a formula and data.
     * @param <M> the model type.
     * @return the validation result. Fit and score times are in milliseconds.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // raw SoftClassifier view of a Classifier over data-frame rows
    public static <M extends DataFrameClassifier> ClassificationValidation<M> of(Formula formula, DataFrame train, DataFrame test, BiFunction<Formula, DataFrame, M> trainer) {
        int[] y = formula.y(train).toIntArray();
        int[] testy = formula.y(test).toIntArray();
        // The posterior width follows the number of labels seen during training.
        int k = MathEx.unique(y).length;

        long start = System.nanoTime();
        M model = trainer.apply(formula, train);
        double fitTime = (System.nanoTime() - start) / 1E6;

        int n = test.nrows();
        int[] prediction = new int[n];
        if (model instanceof SoftClassifier) {
            SoftClassifier soft = (SoftClassifier) model;
            double[][] posteriori = new double[n][k];
            start = System.nanoTime();
            for (int i = 0; i < n; i++) {
                prediction[i] = soft.predict(test.get(i), posteriori[i]);
            }
            double scoreTime = (System.nanoTime() - start) / 1E6;
            return new ClassificationValidation<>(model, testy, prediction, posteriori, fitTime, scoreTime);
        }

        start = System.nanoTime();
        for (int i = 0; i < n; i++) {
            prediction[i] = model.predict(test.get(i));
        }
        double scoreTime = (System.nanoTime() - start) / 1E6;
        return new ClassificationValidation<>(model, testy, prediction, fitTime, scoreTime);
    }

    /**
     * Trains a model on {@code (x, y)} and validates it on {@code (testx, testy)}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param testx the test samples.
     * @param testy the test labels.
     * @param trainer the function that trains a model from samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation result. Fit and score times are in milliseconds.
     */
    public static <T, M extends Classifier<T>> ClassificationValidation<M> of(T[] x, int[] y, T[] testx, int[] testy, BiFunction<T[], int[], M> trainer) {
        // The posterior width follows the number of labels seen during training.
        int k = MathEx.unique(y).length;

        long start = System.nanoTime();
        M model = trainer.apply(x, y);
        double fitTime = (System.nanoTime() - start) / 1E6;

        if (model instanceof SoftClassifier) {
            @SuppressWarnings("unchecked")
            SoftClassifier<T> soft = (SoftClassifier<T>) model;
            double[][] posteriori = new double[testx.length][k];
            start = System.nanoTime();
            int[] prediction = soft.predict(testx, posteriori);
            double scoreTime = (System.nanoTime() - start) / 1E6;
            return new ClassificationValidation<>(model, testy, prediction, posteriori, fitTime, scoreTime);
        } else {
            // Timer starts before prediction so that the score time is actually measured.
            start = System.nanoTime();
            int[] prediction = model.predict(testx);
            double scoreTime = (System.nanoTime() - start) / 1E6;
            return new ClassificationValidation<>(model, testy, prediction, fitTime, scoreTime);
        }
    }

    /**
     * Runs one train/validate round per bag: trains on the bag's samples and
     * validates on its out-of-bag rows.
     *
     * @param bags the bags of sample and out-of-bag indices.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer the function that trains a model from a formula and data.
     * @param <M> the model type.
     * @return the validation results of all rounds.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> of(Bag[] bags, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        List<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            rounds.add(of(formula, data.of(bag.samples), data.of(bag.oob), trainer));
        }
        return new ClassificationValidations<>(rounds);
    }

    /**
     * Runs one train/validate round per bag: trains on the bag's samples and
     * validates on its out-of-bag samples.
     *
     * @param bags the bags of sample and out-of-bag indices.
     * @param x the samples.
     * @param y the labels.
     * @param trainer the function that trains a model from samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation results of all rounds.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> of(Bag[] bags, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        List<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            rounds.add(of(MathEx.slice(x, bag.samples), MathEx.slice(y, bag.samples),
                    MathEx.slice(x, bag.oob), MathEx.slice(y, bag.oob), trainer));
        }
        return new ClassificationValidations<>(rounds);
    }

    @Override
    public String toString() {
        return metrics.toString();
    }
}
