package smile.validation;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.function.BiFunction;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.validation.metric.MAD;
import smile.validation.metric.MSE;
import smile.validation.metric.R2;
import smile.validation.metric.RMSE;
import smile.validation.metric.RSS;

/* loaded from: classes6.dex */
/**
 * The outcome of validating a regression model: the trained model, the
 * ground-truth responses of the validation data, the model's predictions
 * for that data, and the metrics derived from the two.
 *
 * @param <M> the type of the validated model.
 */
public class RegressionValidation<M> implements Serializable {
    private static final long serialVersionUID = 2L;

    /** Number of nanoseconds in one millisecond, used to report timings in ms. */
    private static final double NANOS_PER_MILLI = 1000000.0d;

    /** The metrics computed from {@link #truth} and {@link #prediction}. */
    public final RegressionMetrics metrics;
    /** The model that produced {@link #prediction}. */
    public final M model;
    /** The model's predictions on the validation data. */
    public final double[] prediction;
    /** The ground-truth responses of the validation data. */
    public final double[] truth;

    /**
     * Creates a validation result. The arrays are stored by reference,
     * not copied.
     *
     * @param m the model.
     * @param dArr the ground-truth responses.
     * @param dArr2 the predictions.
     * @param regressionMetrics the validation metrics.
     */
    public RegressionValidation(M m, double[] dArr, double[] dArr2, RegressionMetrics regressionMetrics) {
        this.model = m;
        this.truth = dArr;
        this.prediction = dArr2;
        this.metrics = regressionMetrics;
    }

    /**
     * Trains a model on one data frame and validates it on another.
     * Training and prediction are timed separately (in milliseconds).
     *
     * @param formula the model formula; also used to extract the response
     *                of the validation data.
     * @param dataFrame the training data.
     * @param dataFrame2 the validation data.
     * @param biFunction the trainer that builds a model from a formula and
     *                   the training data.
     * @param <M> the type of model.
     * @return the validation result.
     */
    public static <M extends DataFrameRegression> RegressionValidation<M> of(Formula formula, DataFrame dataFrame, DataFrame dataFrame2, BiFunction<Formula, DataFrame, M> biFunction) {
        double[] truth = formula.y(dataFrame2).toDoubleArray();

        long fitStart = System.nanoTime();
        // Keep the static type M: widening to DataFrameRegression would not
        // be assignable to RegressionValidation<M> below.
        M model = biFunction.apply(formula, dataFrame);
        double fitTime = millisSince(fitStart);

        long scoreStart = System.nanoTime();
        int n = dataFrame2.nrows();
        double[] prediction = new double[n];
        for (int i = 0; i < n; i++) {
            prediction[i] = model.predict(dataFrame2.get(i));
        }
        double scoreTime = millisSince(scoreStart);

        return new RegressionValidation<>(model, truth, prediction, metricsOf(fitTime, scoreTime, truth, prediction));
    }

    /**
     * Trains a model on one sample set and validates it on another.
     * Training and prediction are timed separately (in milliseconds).
     *
     * @param tArr the training samples.
     * @param dArr the training responses.
     * @param tArr2 the validation samples.
     * @param dArr2 the validation responses (ground truth).
     * @param biFunction the trainer that builds a model from samples and
     *                   responses.
     * @param <T> the type of samples.
     * @param <M> the type of model.
     * @return the validation result.
     */
    public static <T, M extends Regression<T>> RegressionValidation<M> of(T[] tArr, double[] dArr, T[] tArr2, double[] dArr2, BiFunction<T[], double[], M> biFunction) {
        long fitStart = System.nanoTime();
        M model = biFunction.apply(tArr, dArr);
        double fitTime = millisSince(fitStart);

        long scoreStart = System.nanoTime();
        double[] prediction = model.predict(tArr2);
        double scoreTime = millisSince(scoreStart);

        return new RegressionValidation<>(model, dArr2, prediction, metricsOf(fitTime, scoreTime, dArr2, prediction));
    }

    /**
     * Runs one train/validate round per bag on a data frame: each round
     * trains on {@code dataFrame.of(bag.samples)} and validates on
     * {@code dataFrame.of(bag.oob)}.
     *
     * @param bagArr the bags, one per validation round.
     * @param formula the model formula.
     * @param dataFrame the full data set.
     * @param biFunction the trainer that builds a model from a formula and
     *                   training data.
     * @param <M> the type of model.
     * @return the results of all rounds, in bag order.
     */
    public static <M extends DataFrameRegression> RegressionValidations<M> of(Bag[] bagArr, Formula formula, DataFrame dataFrame, BiFunction<Formula, DataFrame, M> biFunction) {
        ArrayList<RegressionValidation<M>> rounds = new ArrayList<>(bagArr.length);
        for (Bag bag : bagArr) {
            rounds.add(of(formula, dataFrame.of(bag.samples), dataFrame.of(bag.oob), biFunction));
        }
        return new RegressionValidations<>(rounds);
    }

    /**
     * Runs one train/validate round per bag on sample arrays: each round
     * trains on the slice selected by {@code bag.samples} and validates on
     * the slice selected by {@code bag.oob}.
     *
     * @param bagArr the bags, one per validation round.
     * @param tArr the full set of samples.
     * @param dArr the responses matching {@code tArr}.
     * @param biFunction the trainer that builds a model from samples and
     *                   responses.
     * @param <T> the type of samples.
     * @param <M> the type of model.
     * @return the results of all rounds, in bag order.
     */
    public static <T, M extends Regression<T>> RegressionValidations<M> of(Bag[] bagArr, T[] tArr, double[] dArr, BiFunction<T[], double[], M> biFunction) {
        ArrayList<RegressionValidation<M>> rounds = new ArrayList<>(bagArr.length);
        for (Bag bag : bagArr) {
            T[] trainX = MathEx.slice(tArr, bag.samples);
            double[] trainY = MathEx.slice(dArr, bag.samples);
            T[] testX = MathEx.slice(tArr, bag.oob);
            double[] testY = MathEx.slice(dArr, bag.oob);
            rounds.add(of(trainX, trainY, testX, testY, biFunction));
        }
        return new RegressionValidations<>(rounds);
    }

    /** Returns the milliseconds elapsed since the given {@link System#nanoTime()} reading. */
    private static double millisSince(long startNanos) {
        return (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
    }

    /** Bundles the two timings, the sample count, and RSS, MSE, RMSE, MAD and R2 of the predictions. */
    private static RegressionMetrics metricsOf(double fitTime, double scoreTime, double[] truth, double[] prediction) {
        return new RegressionMetrics(
                fitTime,
                scoreTime,
                truth.length,
                RSS.of(truth, prediction),
                MSE.of(truth, prediction),
                RMSE.of(truth, prediction),
                MAD.of(truth, prediction),
                R2.of(truth, prediction));
    }

    /**
     * Returns the string form of {@link #metrics}.
     */
    @Override
    public String toString() {
        return this.metrics.toString();
    }
}
