package smile.classification;

import com.github.mikephil.charting.utils.Utils;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Properties;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.LongFunction;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.SplitRule;
import smile.classification.RandomForest;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.feature.TreeSHAP;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.Strings;
import smile.validation.ClassificationMetrics;
import smile.validation.metric.AUC;
import smile.validation.metric.Accuracy;
import smile.validation.metric.CrossEntropy;
import smile.validation.metric.Error;
import smile.validation.metric.FScore;
import smile.validation.metric.LogLoss;
import smile.validation.metric.MatthewsCorrelation;
import smile.validation.metric.Precision;
import smile.validation.metric.Sensitivity;
import smile.validation.metric.Specificity;

/**
 * Random forest classifier: an ensemble of decision trees, each grown on a
 * class-stratified sample of the training data (with replacement when the
 * sampling rate is 1.0, without replacement otherwise).
 *
 * <p>Hard prediction ({@link #predict(Tuple)}) is a majority vote of the trees.
 * Soft prediction ({@link #predict(Tuple, double[])}) sums the tree posteriors,
 * each weighted by that tree's out-of-bag (OOB) accuracy, and normalizes them.
 */
public class RandomForest implements SoftClassifier<Tuple>, DataFrameClassifier, TreeSHAP {
    private static final Logger logger = LoggerFactory.getLogger(RandomForest.class);
    private static final long serialVersionUID = 2;

    /** The model formula (already expanded against the training schema when built by fit). */
    private Formula formula;
    /** Variable importance, summed over the trees of this forest. */
    private double[] importance;
    /** The number of classes. */
    private int k;
    /** Maps an encoded class index back to the original class label. */
    private IntSet labels;
    /** Forest-level metrics (OOB-based when the forest was built by fit). */
    private ClassificationMetrics metrics;
    /** The trees of the forest with their OOB metrics. */
    private Model[] models;

    /**
     * A decision tree of the forest together with its out-of-bag metrics.
     * The tree's weight in soft prediction is its OOB accuracy.
     */
    public static class Model implements Serializable {
        /** The out-of-bag metrics of the tree. */
        public final ClassificationMetrics metrics;
        /** The decision tree. */
        public final DecisionTree tree;
        /** The tree weight, i.e. {@code metrics.accuracy}. */
        public final double weight;

        Model(DecisionTree tree, ClassificationMetrics metrics) {
            this.tree = tree;
            this.metrics = metrics;
            this.weight = metrics.accuracy;
        }
    }

    /**
     * Constructs a forest whose class labels are {@code IntSet.of(k)}.
     *
     * @param formula the model formula.
     * @param k the number of classes.
     * @param models the trees of the forest.
     * @param metrics the forest-level metrics.
     * @param importance the variable importance.
     */
    public RandomForest(Formula formula, int k, Model[] models, ClassificationMetrics metrics, double[] importance) {
        this(formula, k, models, metrics, importance, IntSet.of(k));
    }

    /**
     * Constructs a forest.
     *
     * @param formula the model formula.
     * @param k the number of classes.
     * @param models the trees of the forest.
     * @param metrics the forest-level metrics.
     * @param importance the variable importance.
     * @param labels the mapping from encoded class index to class label.
     */
    public RandomForest(Formula formula, int k, Model[] models, ClassificationMetrics metrics, double[] importance, IntSet labels) {
        this.formula = formula;
        this.k = k;
        this.models = models;
        this.metrics = metrics;
        this.importance = importance;
        this.labels = labels;
    }

    /**
     * Fits a random forest with the default hyper-parameters.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted forest.
     */
    public static RandomForest fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits a random forest with uniform class weights and default seeds.
     *
     * @see #fit(Formula, DataFrame, int, int, SplitRule, int, int, int, double, int[], LongStream)
     */
    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample) {
        return fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, null);
    }

    /**
     * Fits a random forest with default seeds.
     *
     * @see #fit(Formula, DataFrame, int, int, SplitRule, int, int, int, double, int[], LongStream)
     */
    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight) {
        return fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, classWeight, null);
    }

    /**
     * Fits a random forest. Trees are grown in parallel.
     *
     * @param formula the model formula; it is expanded against {@code data}'s schema.
     * @param data the training data.
     * @param ntrees the number of trees, at least 1.
     * @param mtry the number of candidate variables per split; a value &le; 0
     *             means {@code (int) sqrt(number of predictors)}.
     * @param rule the split rule.
     * @param maxDepth tree growth limit forwarded to {@link DecisionTree}.
     * @param maxNodes tree growth limit forwarded to {@link DecisionTree}.
     * @param nodeSize tree growth limit forwarded to {@link DecisionTree}.
     * @param subsample the sampling rate in (0, 1]. At 1.0 each class is sampled
     *                  with replacement; otherwise a fraction of each class is
     *                  drawn without replacement.
     * @param classWeight per-class integer divisor applied to the number of samples
     *                    drawn from that class; {@code null} means all ones.
     * @param seeds the per-tree RNG seeds; must supply at least {@code ntrees}
     *              distinct values. {@code null} uses {@code -ntrees .. -1}.
     *              Only seeds greater than 1 reseed the RNG.
     * @return the fitted forest.
     * @throws IllegalArgumentException if a hyper-parameter is out of range,
     *         {@code classWeight} has fewer entries than classes, or
     *         {@code seeds} has too few distinct values.
     */
    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, LongStream seeds) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("Invalid sampling rating: " + subsample);
        }

        Formula expanded = formula.expand(data.schema());
        DataFrame x = expanded.x(data);
        BaseVector y = expanded.y(data);

        if (mtry > x.ncols()) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        int mtryFinal = mtry > 0 ? mtry : (int) Math.sqrt(x.ncols());

        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        int n = x.nrows();

        final int[] weight;
        if (classWeight == null) {
            weight = new int[k];
            Arrays.fill(weight, 1);
        } else {
            // A short array would otherwise fail with AIOOBE inside a worker thread.
            if (classWeight.length < k) {
                throw new IllegalArgumentException(String.format("Invalid class weight size: %d, expected: %d", classWeight.length, k));
            }
            weight = classWeight;
        }

        int[][] order = CART.order(x);
        // votes[i][c]: number of trees that predicted class c for sample i while i was out of bag.
        int[][] votes = new int[n][k];

        long[] seedArray = (seeds != null ? seeds : LongStream.range(-ntrees, 0L)).sequential().distinct().limit(ntrees).toArray();
        if (seedArray.length != ntrees) {
            throw new IllegalArgumentException(String.format("seed stream has only %d distinct values, expected %d", seedArray.length, ntrees));
        }

        // Number of samples in each class.
        int[] count = new int[k];
        for (int i = 0; i < n; i++) {
            count[codec.y[i]]++;
        }

        // Row indices of the samples in each class, for stratified sampling.
        int[][] classIndex = new int[k][];
        for (int c = 0; c < k; c++) {
            classIndex[c] = new int[count[c]];
        }
        int[] filled = new int[k];
        for (int i = 0; i < n; i++) {
            int c = codec.y[i];
            classIndex[c][filled[c]++] = i;
        }

        Model[] models = Arrays.stream(seedArray).parallel()
                .mapToObj(seed -> growTree(seed, x, codec, y, k, rule, maxDepth, maxNodes, nodeSize, mtryFinal, subsample, count, weight, classIndex, order, votes))
                .toArray(Model[]::new);

        double fitTime = 0.0;
        double scoreTime = 0.0;
        for (Model model : models) {
            fitTime += model.metrics.fitTime;
            scoreTime += model.metrics.scoreTime;
        }

        int[] vote = new int[n];
        for (int i = 0; i < n; i++) {
            vote[i] = MathEx.whichMax(votes[i]);
        }

        ClassificationMetrics metrics = new ClassificationMetrics(fitTime, scoreTime, n, Error.of(codec.y, vote), Accuracy.of(codec.y, vote));
        return new RandomForest(expanded, k, models, metrics, importance(models), codec.labels);
    }

    /**
     * Fits a random forest with hyper-parameters read from properties.
     *
     * <p>Keys and defaults: {@code smile.random.forest.trees} (500),
     * {@code smile.random.forest.mtry} (0), {@code smile.random.forest.split.rule} (GINI),
     * {@code smile.random.forest.max.depth} (20), {@code smile.random.forest.max.nodes}
     * ({@code data.size() / 5}), {@code smile.random.forest.node.size} (5),
     * {@code smile.random.forest.sample.rate} (1.0), and {@code smile.random.forest.class.weight}
     * (unset).
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param params the hyper-parameters.
     * @return the fitted forest.
     */
    public static RandomForest fit(Formula formula, DataFrame data, Properties params) {
        int ntrees = Integer.parseInt(params.getProperty("smile.random.forest.trees", "500"));
        int mtry = Integer.parseInt(params.getProperty("smile.random.forest.mtry", "0"));
        SplitRule rule = SplitRule.valueOf(params.getProperty("smile.random.forest.split.rule", "GINI"));
        int maxDepth = Integer.parseInt(params.getProperty("smile.random.forest.max.depth", "20"));
        int maxNodes = Integer.parseInt(params.getProperty("smile.random.forest.max.nodes", String.valueOf(data.size() / 5)));
        int nodeSize = Integer.parseInt(params.getProperty("smile.random.forest.node.size", "5"));
        double subsample = Double.parseDouble(params.getProperty("smile.random.forest.sample.rate", "1.0"));
        int[] classWeight = Strings.parseIntArray(params.getProperty("smile.random.forest.class.weight"));
        return fit(formula, data, ntrees, mtry, rule, maxDepth, maxNodes, nodeSize, subsample, classWeight, null);
    }

    /**
     * Grows one tree on a class-stratified sample, scores it on its out-of-bag
     * samples, and adds its OOB predictions to the shared {@code votes} matrix.
     */
    private static Model growTree(long seed, DataFrame x, ClassLabels codec, BaseVector y, int k, SplitRule rule,
                                  int maxDepth, int maxNodes, int nodeSize, int mtry, double subsample,
                                  int[] count, int[] weight, int[][] classIndex, int[][] order, int[][] votes) {
        if (seed > 1) {
            MathEx.setSeed(seed);
        }

        int n = x.nrows();
        // samples[i]: how many times row i is drawn into this tree's training set.
        int[] samples = new int[n];
        for (int c = 0; c < k; c++) {
            int[] members = classIndex[c];
            if (subsample == 1.0) {
                // Draw with replacement, down-sampled by the class weight.
                int size = count[c] / weight[c];
                for (int j = 0; j < size; j++) {
                    samples[members[MathEx.randomInt(count[c])]]++;
                }
            } else {
                // Draw without replacement, down-sampled by the class weight.
                int size = (int) Math.round((count[c] * subsample) / weight[c]);
                int[] permutation = MathEx.permutate(count[c]);
                for (int j = 0; j < size; j++) {
                    samples[members[permutation[j]]]++;
                }
            }
        }

        long start = System.nanoTime();
        DecisionTree tree = new DecisionTree(x, codec.y, y.field(), k, rule, maxDepth, maxNodes, nodeSize, mtry, samples, order);
        double fitTime = (System.nanoTime() - start) / 1E6;

        start = System.nanoTime();
        int noob = 0;
        for (int s : samples) {
            if (s == 0) {
                noob++;
            }
        }

        int[] truth = new int[noob];
        int[] oob = new int[noob];
        double[][] posteriori = new double[noob][k];
        for (int i = 0, j = 0; i < n; i++) {
            if (samples[i] == 0) {
                truth[j] = codec.y[i];
                int p = tree.predict(x.get(i), posteriori[j]);
                oob[j] = p;
                // Trees are grown concurrently and share OOB rows; an unguarded
                // increment here would lose votes.
                synchronized (votes[i]) {
                    votes[i][p]++;
                }
                j++;
            }
        }
        double scoreTime = (System.nanoTime() - start) / 1E6;

        // On small data the OOB samples may not cover every class.
        ClassificationMetrics metrics;
        if (MathEx.unique(truth).length == 2) {
            double[] probability = Arrays.stream(posteriori).mapToDouble(prob -> prob[1]).toArray();
            metrics = new ClassificationMetrics(fitTime, scoreTime, noob,
                    Error.of(truth, oob),
                    Accuracy.of(truth, oob),
                    Sensitivity.of(truth, oob),
                    Specificity.of(truth, oob),
                    Precision.of(truth, oob),
                    FScore.F1.score(truth, oob),
                    MatthewsCorrelation.of(truth, oob),
                    AUC.of(truth, probability),
                    LogLoss.of(truth, probability));
        } else {
            metrics = new ClassificationMetrics(fitTime, scoreTime, noob,
                    Error.of(truth, oob),
                    Accuracy.of(truth, oob),
                    CrossEntropy.of(truth, posteriori));
        }

        if (noob != 0) {
            logger.info("Random forest tree OOB metrics: {}", metrics);
        } else {
            logger.error("Random forest has a tree trained without OOB samples.");
        }

        return new Model(tree, metrics);
    }

    /** Sums the variable importance of the given (non-empty) trees. */
    private static double[] importance(Model[] models) {
        int p = models[0].tree.importance().length;
        double[] sum = new double[p];
        for (Model model : models) {
            double[] treeImportance = model.tree.importance();
            for (int i = 0; i < p; i++) {
                sum[i] += treeImportance[i];
            }
        }
        return sum;
    }

    @Override // smile.classification.DataFrameClassifier, smile.feature.TreeSHAP
    public Formula formula() {
        return formula;
    }

    /**
     * Returns the variable importance, summed over the trees.
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Merges two forests trained with the same formula.
     *
     * <p>The metrics of the result are a rough estimate. Fit and score times are
     * summed; the other metrics are the mean of the two forests' values; the
     * size is taken from this forest.
     *
     * @param other the forest to merge with.
     * @return the merged forest.
     * @throws IllegalArgumentException if the formulas or the numbers of classes differ.
     */
    public RandomForest merge(RandomForest other) {
        if (!formula.equals(other.formula)) {
            throw new IllegalArgumentException("RandomForest have different model formula");
        }
        if (k != other.k) {
            throw new IllegalArgumentException(String.format("RandomForest have different number of classes: %d vs %d", k, other.k));
        }

        Model[] forest = new Model[models.length + other.models.length];
        System.arraycopy(models, 0, forest, 0, models.length);
        System.arraycopy(other.models, 0, forest, models.length, other.models.length);

        ClassificationMetrics merged = new ClassificationMetrics(
                metrics.fitTime + other.metrics.fitTime,
                metrics.scoreTime + other.metrics.scoreTime,
                metrics.size,
                (metrics.error + other.metrics.error) / 2,
                (metrics.accuracy + other.metrics.accuracy) / 2.0d,
                (metrics.sensitivity + other.metrics.sensitivity) / 2.0d,
                (metrics.specificity + other.metrics.specificity) / 2.0d,
                (metrics.precision + other.metrics.precision) / 2.0d,
                (metrics.f1 + other.metrics.f1) / 2.0d,
                (metrics.mcc + other.metrics.mcc) / 2.0d,
                (metrics.auc + other.metrics.auc) / 2.0d,
                (metrics.logloss + other.metrics.logloss) / 2.0d,
                (metrics.crossentropy + other.metrics.crossentropy) / 2.0d);

        double[] mergedImportance = importance.clone();
        for (int i = 0; i < mergedImportance.length; i++) {
            mergedImportance[i] += other.importance[i];
        }
        return new RandomForest(formula, k, forest, merged, mergedImportance, labels);
    }

    /**
     * Returns the forest-level metrics.
     *
     * @return the forest-level metrics.
     */
    public ClassificationMetrics metrics() {
        return metrics;
    }

    /**
     * Returns the trees of the forest with their OOB metrics.
     *
     * @return the base models.
     */
    public Model[] models() {
        return models;
    }

    /** Predicts the class label by majority vote of the trees. */
    @Override // smile.classification.Classifier
    public int predict(Tuple x) {
        Tuple xt = formula.x(x);
        int[] votes = new int[k];
        for (Model model : models) {
            votes[model.tree.predict(xt)]++;
        }
        return labels.valueOf(MathEx.whichMax(votes));
    }

    /**
     * Predicts the class label and fills {@code posteriori} with the normalized
     * sum of tree posteriors, each weighted by the tree's OOB accuracy.
     */
    @Override // smile.classification.SoftClassifier
    public int predict(Tuple x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        Tuple xt = formula.x(x);
        double[] prob = new double[k];
        Arrays.fill(posteriori, 0.0);
        for (Model model : models) {
            model.tree.predict(xt, prob);
            for (int i = 0; i < k; i++) {
                posteriori[i] += model.weight * prob[i];
            }
        }

        MathEx.unitize1(posteriori);
        return labels.valueOf(MathEx.whichMax(posteriori));
    }

    /**
     * Prunes every tree against the given data. Tree weights, forest metrics and
     * OOB metrics are carried over unchanged; importance is recomputed.
     *
     * @param test the data used for pruning.
     * @return the pruned forest.
     */
    public RandomForest prune(DataFrame test) {
        Model[] forest = Arrays.stream(models).parallel()
                .map(model -> new Model(model.tree.prune(test, formula, labels), model.metrics))
                .toArray(Model[]::new);
        return new RandomForest(formula, k, forest, metrics, importance(forest), labels);
    }

    @Override // smile.classification.DataFrameClassifier
    public StructType schema() {
        return models[0].tree.schema();
    }

    /**
     * Returns the number of trees.
     *
     * @return the number of trees.
     */
    public int size() {
        return models.length;
    }

    /**
     * Evaluates the forest as trees are added one at a time.
     *
     * <p>{@code prediction[t][j]} is the majority-vote encoded class index (not
     * mapped through the labels) of sample {@code j} using trees {@code 0..t}.
     *
     * @param data the test data.
     * @return the cumulative predictions, of size {@code size() x data.size()}.
     */
    public int[][] test(DataFrame data) {
        DataFrame x = formula.x(data);
        int n = x.size();
        int ntrees = models.length;
        int[] votes = new int[k];
        int[][] prediction = new int[ntrees][n];
        for (int j = 0; j < n; j++) {
            Tuple xj = x.get(j);
            Arrays.fill(votes, 0);
            for (int t = 0; t < ntrees; t++) {
                votes[models[t].tree.predict(xj)]++;
                prediction[t][j] = MathEx.whichMax(votes);
            }
        }
        return prediction;
    }

    @Override // smile.feature.TreeSHAP
    public DecisionTree[] trees() {
        return Arrays.stream(models).map(model -> model.tree).toArray(DecisionTree[]::new);
    }

    /**
     * Returns a new forest made of the {@code ntrees} trees with the highest
     * weight (OOB accuracy). This forest is left unchanged. The forest-level
     * metrics are carried over as-is (the training data is not available here).
     * The importance is recomputed from the kept trees.
     *
     * @param ntrees the new number of trees, in {@code [1, size()]}.
     * @return the trimmed forest.
     * @throws IllegalArgumentException if {@code ntrees} is out of range.
     */
    public RandomForest trim(int ntrees) {
        if (ntrees > models.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        // Sort a copy so that this forest's model order is not mutated.
        Model[] sorted = models.clone();
        Arrays.sort(sorted, Comparator.comparingDouble((Model model) -> -model.weight));
        Model[] kept = Arrays.copyOf(sorted, ntrees);
        return new RandomForest(formula, k, kept, metrics, importance(kept), labels);
    }

    /**
     * Predicts by majority vote and fills {@code posteriori} with the normalized
     * vote shares.
     *
     * @param x the instance.
     * @param posteriori output vote shares, of length {@code k}.
     * @return the predicted class label.
     */
    public int vote(Tuple x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        Tuple xt = formula.x(x);
        Arrays.fill(posteriori, 0.0);
        for (Model model : models) {
            posteriori[model.tree.predict(xt)] += 1.0;
        }

        MathEx.unitize1(posteriori);
        return labels.valueOf(MathEx.whichMax(posteriori));
    }
}
