package smile.regression;

import com.github.mikephil.charting.utils.Utils;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Properties;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.LongFunction;
import java.util.function.ToDoubleFunction;
import java.util.stream.LongStream;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.feature.TreeSHAP;
import smile.math.MathEx;
import smile.regression.RandomForest;
import smile.validation.RegressionMetrics;
import smile.validation.metric.MAD;
import smile.validation.metric.MSE;
import smile.validation.metric.R2;
import smile.validation.metric.RMSE;
import smile.validation.metric.RSS;

/* loaded from: classes6.dex */
/**
 * Random forest for regression: an ensemble of regression trees, each grown on a random
 * sample of the training rows. The forest prediction is the mean of the tree predictions.
 * Rows that were not drawn for a given tree ("out-of-bag", OOB) are used to score that
 * tree, and the per-row OOB averages give the forest-level {@link #metrics()}.
 */
public class RandomForest implements Regression<Tuple>, DataFrameRegression, TreeSHAP {
    private static final long serialVersionUID = 2;

    /** The model formula. */
    private final Formula formula;
    /** Variable importance of the forest. */
    private final double[] importance;
    /** Out-of-bag metrics of the forest. */
    private final RegressionMetrics metrics;
    /** The trees of the forest, each with its own out-of-bag metrics. */
    private final Model[] models;

    /** A tree of the forest together with the metrics measured on its out-of-bag rows. */
    public static class Model implements Serializable {
        /** Metrics of this tree on its out-of-bag rows. */
        public final RegressionMetrics metrics;
        /** The regression tree. */
        public final RegressionTree tree;

        Model(RegressionTree tree, RegressionMetrics metrics) {
            this.tree = tree;
            this.metrics = metrics;
        }
    }

    /**
     * Constructor.
     *
     * @param formula the model formula.
     * @param models the trees of the forest.
     * @param metrics the out-of-bag metrics of the forest.
     * @param importance the variable importance.
     */
    public RandomForest(Formula formula, Model[] models, RegressionMetrics metrics, double[] importance) {
        this.formula = formula;
        this.models = models;
        this.metrics = metrics;
        this.importance = importance;
    }

    /**
     * Sums the per-variable importance of all trees.
     * The array must be non-empty; the result has the length of the first tree's importance.
     */
    private static double[] calculateImportance(Model[] models) {
        double[] total = new double[models[0].tree.importance().length];
        for (Model model : models) {
            double[] treeImportance = model.tree.importance();
            for (int i = 0; i < treeImportance.length; i++) {
                total[i] += treeImportance[i];
            }
        }
        return total;
    }

    /**
     * Fits a random forest with the default hyper-parameters (see {@link #fit(Formula, DataFrame, Properties)}).
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @return the model.
     */
    public static RandomForest fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits a random forest with the default seed stream.
     *
     * @see #fit(Formula, DataFrame, int, int, int, int, int, double, LongStream)
     */
    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample) {
        return fit(formula, data, ntrees, mtry, maxDepth, maxNodes, nodeSize, subsample, null);
    }

    /**
     * Fits a random forest. The trees are grown in parallel.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param ntrees the number of trees; must be at least 1.
     * @param mtry the number of variables randomly considered at each split; a value
     *             {@code <= 0} selects {@code max(p / 3, 1)} where p is the number of variables.
     * @param maxDepth the maximum depth of a tree.
     * @param maxNodes the maximum number of leaf nodes in a tree.
     * @param nodeSize the node size limit of a tree.
     * @param subsample the sampling rate in (0, 1]. With 1.0 each tree uses a bootstrap
     *                  sample (with replacement); otherwise rows are drawn without replacement.
     * @param seeds the seeds of the trees; the first {@code ntrees} distinct values are used.
     *              When {@code null}, negative seeds are used, which do not reseed the RNG.
     * @return the model.
     * @throws IllegalArgumentException if a hyper-parameter is out of range or
     *         {@code seeds} has fewer than {@code ntrees} distinct values.
     */
    public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, final int maxDepth, final int maxNodes, final int nodeSize, final double subsample, LongStream seeds) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        // Negated range test so that NaN is rejected too.
        if (!(subsample > 0.0 && subsample <= 1.0)) {
            throw new IllegalArgumentException("Invalid sampling rate: " + subsample);
        }

        Formula expanded = formula.expand(data.schema());
        DataFrame x = expanded.x(data);
        BaseVector response = expanded.y(data);
        StructField field = response.field();
        double[] y = response.toDoubleArray();

        if (mtry > x.ncols()) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        int m = mtry > 0 ? mtry : Math.max(x.ncols() / 3, 1);

        int n = x.nrows();
        // Shared OOB accumulators: sum of OOB predictions and OOB count per row.
        double[] prediction = new double[n];
        int[] oob = new int[n];
        int[][] order = CART.order(x);

        LongStream seedSource = seeds != null ? seeds : LongStream.range(-ntrees, 0L);
        long[] seedArray = seedSource.sequential().distinct().limit(ntrees).toArray();
        if (seedArray.length != ntrees) {
            throw new IllegalArgumentException(String.format("seed stream has only %d distinct values, expected %d", seedArray.length, ntrees));
        }

        Model[] forest = Arrays.stream(seedArray)
                .parallel()
                .mapToObj(seed -> growTree(n, subsample, x, y, field, maxDepth, maxNodes, nodeSize, m, order, oob, prediction, seed))
                .toArray(Model[]::new);

        double fitTime = 0.0;
        double scoreTime = 0.0;
        for (Model model : forest) {
            fitTime += model.metrics.fitTime;
            scoreTime += model.metrics.scoreTime;
        }

        // Average the OOB predictions; rows never out-of-bag keep 0.0.
        for (int i = 0; i < n; i++) {
            if (oob[i] > 0) {
                prediction[i] /= oob[i];
            }
        }

        RegressionMetrics forestMetrics = new RegressionMetrics(fitTime, scoreTime, n,
                RSS.of(y, prediction), MSE.of(y, prediction), RMSE.of(y, prediction),
                MAD.of(y, prediction), R2.of(y, prediction));
        return new RandomForest(expanded, forest, forestMetrics, calculateImportance(forest));
    }

    /**
     * Fits a random forest with hyper-parameters read from {@code params}:
     * {@code smile.random.forest.trees} (default 500), {@code .mtry} (0),
     * {@code .max.depth} (20), {@code .max.nodes} (data size / 5),
     * {@code .node.size} (5) and {@code .sample.rate} (1.0).
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param params the hyper-parameters.
     * @return the model.
     */
    public static RandomForest fit(Formula formula, DataFrame data, Properties params) {
        int ntrees = Integer.parseInt(params.getProperty("smile.random.forest.trees", "500"));
        int mtry = Integer.parseInt(params.getProperty("smile.random.forest.mtry", "0"));
        int maxDepth = Integer.parseInt(params.getProperty("smile.random.forest.max.depth", "20"));
        int maxNodes = Integer.parseInt(params.getProperty("smile.random.forest.max.nodes", String.valueOf(data.size() / 5)));
        int nodeSize = Integer.parseInt(params.getProperty("smile.random.forest.node.size", "5"));
        double subsample = Double.parseDouble(params.getProperty("smile.random.forest.sample.rate", "1.0"));
        return fit(formula, data, ntrees, mtry, maxDepth, maxNodes, nodeSize, subsample);
    }

    /**
     * Grows one tree and scores it on its out-of-bag rows.
     *
     * <p>With {@code subsample == 1.0}, n rows are drawn with replacement (a row may be counted
     * several times). Otherwise {@code round(n * subsample)} distinct rows are drawn. Rows not
     * drawn are out-of-bag. The tree's predictions for them are added into the shared
     * {@code oob}/{@code prediction} accumulators.
     *
     * @param seed the random seed; only values greater than 1 reseed {@link MathEx}.
     */
    private static Model growTree(int n, double subsample, DataFrame x, double[] y, StructField field, int maxDepth, int maxNodes, int nodeSize, int mtry, int[][] order, int[] oob, double[] prediction, long seed) {
        if (seed > 1) {
            MathEx.setSeed(seed);
        }

        int[] samples = new int[n];
        if (subsample == 1.0) {
            for (int i = 0; i < n; i++) {
                samples[MathEx.randomInt(n)]++;
            }
        } else {
            int[] permutation = MathEx.permutate(n);
            int size = (int) Math.round(n * subsample);
            for (int i = 0; i < size; i++) {
                samples[permutation[i]] = 1;
            }
        }

        long start = System.nanoTime();
        RegressionTree tree = new RegressionTree(x, Loss.ls(y), field, maxDepth, maxNodes, nodeSize, mtry, samples, order);
        double fitTime = (System.nanoTime() - start) / 1E6;

        start = System.nanoTime();
        int noob = 0;
        for (int count : samples) {
            if (count == 0) {
                noob++;
            }
        }

        int[] rows = new int[noob];
        double[] truth = new double[noob];
        double[] estimate = new double[noob];
        for (int i = 0, j = 0; i < n; i++) {
            if (samples[i] == 0) {
                rows[j] = i;
                truth[j] = y[i];
                estimate[j] = tree.predict(x.get(i));
                j++;
            }
        }

        // Trees are grown concurrently and share these accumulators. The read-modify-write
        // updates must be serialized, or concurrent trees would lose each other's updates.
        synchronized (oob) {
            for (int j = 0; j < noob; j++) {
                oob[rows[j]]++;
                prediction[rows[j]] += estimate[j];
            }
        }
        double scoreTime = (System.nanoTime() - start) / 1E6;

        return new Model(tree, new RegressionMetrics(fitTime, scoreTime, noob,
                RSS.of(truth, estimate), MSE.of(truth, estimate), RMSE.of(truth, estimate),
                MAD.of(truth, estimate), R2.of(truth, estimate)));
    }

    @Override
    public Formula formula() {
        return formula;
    }

    /**
     * Returns the variable importance. For forests built by {@code fit}, this is the sum of
     * the importance of all trees.
     *
     * @return the variable importance.
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Merges two forests trained with the same formula into one forest containing the trees
     * of both.
     *
     * <p>The per-row OOB predictions are not kept, so the merged metrics are only a rough
     * estimate. Fit/score times are summed, error metrics are averaged, and the size is taken
     * from this forest. Importance is summed.
     *
     * @param other the forest to merge with.
     * @return the merged forest.
     * @throws IllegalArgumentException if the formulas differ.
     */
    public RandomForest merge(RandomForest other) {
        if (!formula.equals(other.formula)) {
            throw new IllegalArgumentException("RandomForest have different model formula");
        }

        Model[] forest = Arrays.copyOf(models, models.length + other.models.length);
        System.arraycopy(other.models, 0, forest, models.length, other.models.length);

        RegressionMetrics mergedMetrics = new RegressionMetrics(
                metrics.fitTime + other.metrics.fitTime,
                metrics.scoreTime + other.metrics.scoreTime,
                metrics.size,
                (metrics.rss + other.metrics.rss) / 2.0,
                (metrics.mse + other.metrics.mse) / 2.0,
                (metrics.rmse + other.metrics.rmse) / 2.0,
                (metrics.mad + other.metrics.mad) / 2.0,
                (metrics.r2 + other.metrics.r2) / 2.0);

        double[] mergedImportance = importance.clone();
        for (int i = 0; i < mergedImportance.length; i++) {
            mergedImportance[i] += other.importance[i];
        }
        return new RandomForest(formula, forest, mergedMetrics, mergedImportance);
    }

    /**
     * Returns the out-of-bag metrics of the forest.
     *
     * @return the out-of-bag metrics.
     */
    public RegressionMetrics metrics() {
        return metrics;
    }

    /**
     * Returns the trees of the forest with their individual out-of-bag metrics.
     *
     * @return the trees of the forest.
     */
    public Model[] models() {
        return models;
    }

    /** Predicts the response as the mean of the tree predictions. */
    @Override
    public double predict(Tuple tuple) {
        Tuple x = formula.x(tuple);
        double sum = 0.0;
        for (Model model : models) {
            sum += model.tree.predict(x);
        }
        return sum / models.length;
    }

    @Override
    public StructType schema() {
        return models[0].tree.schema();
    }

    /**
     * Returns the number of trees in the forest.
     *
     * @return the number of trees.
     */
    public int size() {
        return models.length;
    }

    /**
     * Predicts each row with growing prefixes of the forest.
     * Entry {@code [k][i]} is the mean prediction of the first {@code k + 1} trees for row i.
     *
     * @param data the test data.
     * @return a {@code size() x nrows} matrix of cumulative mean predictions.
     */
    public double[][] test(DataFrame data) {
        DataFrame x = formula.x(data);
        int n = x.nrows();
        int ntrees = models.length;
        double[][] prediction = new double[ntrees][n];
        for (int i = 0; i < n; i++) {
            Tuple xi = x.get(i);
            double sum = 0.0;
            for (int k = 0; k < ntrees; k++) {
                sum += models[k].tree.predict(xi);
                prediction[k][i] = sum / (k + 1);
            }
        }
        return prediction;
    }

    @Override
    public RegressionTree[] trees() {
        return Arrays.stream(models).map(model -> model.tree).toArray(RegressionTree[]::new);
    }

    /**
     * Returns a new forest made of the {@code ntrees} trees with the lowest out-of-bag RMSE.
     * This forest is not modified. The returned forest keeps this forest's OOB metrics,
     * which cannot be recomputed without the training data. Its importance is recomputed
     * from the retained trees.
     *
     * @param ntrees the number of trees to keep.
     * @return the trimmed forest.
     * @throws IllegalArgumentException if {@code ntrees} is not in [1, size()].
     */
    public RandomForest trim(int ntrees) {
        if (ntrees > models.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        // Sort a copy so the tree order of this forest (seen by test()/models()) is unchanged.
        Model[] sorted = models.clone();
        Arrays.sort(sorted, Comparator.comparingDouble((Model model) -> model.metrics.rmse));
        Model[] kept = Arrays.copyOf(sorted, ntrees);
        return new RandomForest(formula, kept, metrics, calculateImportance(kept));
    }
}
