package smile.classification;

import com.github.mikephil.charting.utils.Utils;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.IntFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.util.IntSet;

/* loaded from: classes6.dex */
/**
 * One-versus-one multi-class classifier built from pairwise binary classifiers.
 *
 * <p>For {@code k} classes the model stores a lower-triangular table of binary
 * classifiers: entry {@code [i][j]} (with {@code j < i}) separates class
 * {@code i} from class {@code j}. Hard prediction is a majority vote over all
 * pairs. Soft prediction converts each pairwise score into a probability via
 * the matching {@link PlattScaling} entry and then combines the pairwise
 * probabilities into a single class distribution (see
 * {@link #predict(Object, double[])}).
 *
 * @param <T> the type of the input samples
 */
public class OneVersusOne<T> implements SoftClassifier<T> {
    private static final Logger logger = LoggerFactory.getLogger(OneVersusOne.class);
    private static final long serialVersionUID = 2L;

    /** Lower-triangular table of pairwise classifiers; row 0 is never read. */
    private final Classifier<T>[][] classifiers;
    /** Number of classes, taken from the length of {@link #classifiers}. */
    private final int k;
    /** Maps the winning class index back to the label returned to the caller. */
    private final IntSet labels;
    /** Lower-triangular table of Platt scalers, or {@code null} if unavailable. */
    private final PlattScaling[][] platts;

    /**
     * Creates a model whose label set is {@code IntSet.of(classifierArr.length)}.
     *
     * @param classifierArr   lower-triangular table of pairwise classifiers
     * @param plattScalingArr matching table of Platt scalers; may be {@code null},
     *                        in which case soft prediction is unsupported
     */
    public OneVersusOne(Classifier<T>[][] classifierArr, PlattScaling[][] plattScalingArr) {
        this(classifierArr, plattScalingArr, IntSet.of(classifierArr.length));
    }

    /**
     * Creates a model with an explicit label set.
     *
     * @param classifierArr   lower-triangular table of pairwise classifiers
     * @param plattScalingArr matching table of Platt scalers; may be {@code null},
     *                        in which case soft prediction is unsupported
     * @param intSet          label set used to translate the winning class index
     */
    public OneVersusOne(Classifier<T>[][] classifierArr, PlattScaling[][] plattScalingArr, IntSet intSet) {
        this.classifiers = classifierArr;
        this.platts = plattScalingArr;
        this.k = classifierArr.length;
        this.labels = intSet;
    }

    /**
     * Combines pairwise probabilities into one distribution over the {@code k}
     * classes using a fixed-point iteration.
     *
     * <p>NOTE(review): the tolerance {@code 0.005 / k} and the iteration cap
     * {@code max(100, k)} look like the pairwise-coupling routine used by
     * LIBSVM (Wu, Lin and Weng) - confirm against the reference.
     *
     * @param r pairwise probabilities, {@code k x k}; only off-diagonal entries are read
     * @param p output; entries {@code 0..k-1} are overwritten with the result
     */
    private void coupling(double[][] r, double[] p) {
        final double[][] q = new double[k][k];
        final double[] qp = new double[k];
        final double eps = 0.005d / k;

        // Build the symmetric matrix q and start from the uniform distribution.
        for (int t = 0; t < k; t++) {
            p[t] = 1.0d / k;
            q[t][t] = 0.0d;
            for (int j = 0; j < t; j++) {
                q[t][t] += r[j][t] * r[j][t];
                q[t][j] = q[j][t]; // already computed when row j was filled
            }
            for (int j = t + 1; j < k; j++) {
                q[t][t] += r[j][t] * r[j][t];
                q[t][j] = (-r[j][t]) * r[t][j];
            }
        }

        int iter = 0;
        final int maxIter = Math.max(100, k);
        for (; iter < maxIter; iter++) {
            // Recompute q*p and p'*q*p from scratch at the start of each sweep.
            double pQp = 0.0d;
            for (int t = 0; t < k; t++) {
                qp[t] = 0.0d;
                for (int j = 0; j < k; j++) {
                    qp[t] += q[t][j] * p[j];
                }
                pQp += p[t] * qp[t];
            }

            // Stop once every component of q*p is within eps of p'*q*p.
            double maxError = 0.0d;
            for (int t = 0; t < k; t++) {
                double error = Math.abs(qp[t] - pQp);
                if (error > maxError) {
                    maxError = error;
                }
            }
            if (maxError < eps) {
                break;
            }

            // Coordinate-wise update followed by renormalisation by (1 + diff).
            for (int t = 0; t < k; t++) {
                double diff = ((-qp[t]) + pQp) / q[t][t];
                p[t] += diff;
                pQp = ((pQp + (((q[t][t] * diff) + (qp[t] * 2.0d)) * diff)) / (diff + 1.0d)) / (diff + 1.0d);
                for (int j = 0; j < k; j++) {
                    qp[j] = (qp[j] + (q[t][j] * diff)) / (diff + 1.0d);
                    p[j] = p[j] / (diff + 1.0d);
                }
            }
        }

        if (iter >= maxIter) {
            logger.warn("coupling reaches maximal iterations");
        }
    }

    /**
     * Fits a one-versus-one model on a data frame and wraps it as a
     * {@link DataFrameClassifier}.
     *
     * <p>The rows of {@code dataFrame} are split per class pair and each
     * subset is turned back into a data frame that is handed to
     * {@code biFunction} together with {@code formula}. The positive/negative
     * labels used for the pairwise problems are {@code 1} and {@code 0}.
     *
     * <p>NOTE(review): the re-labelled {@code int[]} passed to the trainer
     * lambda is not forwarded to {@code biFunction}; the trainer presumably
     * derives the response from {@code formula} on the sub-frame - confirm
     * this is intended.
     *
     * @param formula    the model formula; also used to extract the response
     * @param dataFrame  the training data
     * @param biFunction trains a binary classifier from a formula and a data frame
     * @return a classifier that delegates prediction to the fitted model
     */
    public static DataFrameClassifier fit(final Formula formula, DataFrame dataFrame, final BiFunction<Formula, DataFrame, DataFrameClassifier> biFunction) {
        final Tuple[] rows = dataFrame.stream().toArray(Tuple[]::new);
        final int[] y = formula.y(dataFrame).toIntArray();

        final OneVersusOne<Tuple> model = fit(rows, y, 1, 0,
                (Tuple[] subset, int[] subsetLabels) -> biFunction.apply(formula, DataFrame.of(Arrays.asList(subset))));

        final StructType schema = formula.x(dataFrame.get(0)).schema();
        return new DataFrameClassifier() {
            @Override
            public Formula formula() {
                return formula;
            }

            @Override
            public int predict(Tuple tuple) {
                // Delegate to the trained model; a static method has no enclosing instance.
                return model.predict(tuple);
            }

            @Override
            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Fits one binary classifier per pair of classes.
     *
     * <p>For every pair {@code (hi, lo)} with {@code lo < hi} the samples of
     * both classes are copied into a fresh array; samples of class {@code hi}
     * are labelled {@code i} and samples of class {@code lo} are labelled
     * {@code i2}. After the first pair is trained, the classifier's
     * {@code score} method is probed once: if it throws
     * {@link UnsupportedOperationException}, no Platt scaling is fitted and
     * the resulting model does not support soft prediction.
     *
     * <p>NOTE(review): the model is built with the default label set
     * {@code IntSet.of(k)}. If {@code ClassLabels.fit} re-encodes labels that
     * are not already {@code 0..k-1}, predictions may be reported in the
     * encoded space rather than the original one - confirm against
     * {@code ClassLabels}.
     *
     * @param tArr       training samples
     * @param iArr       training labels, same length as {@code tArr}
     * @param i          label given to the higher-indexed class of each pair (positive)
     * @param i2         label given to the lower-indexed class of each pair (negative)
     * @param biFunction trains a binary classifier from samples and labels
     * @param <T>        the sample type
     * @return the fitted model
     * @throws IllegalArgumentException if the array lengths differ or there
     *                                  are two or fewer classes
     */
    @SuppressWarnings("unchecked")
    public static <T> OneVersusOne<T> fit(T[] tArr, int[] iArr, int i, int i2, BiFunction<T[], int[], Classifier<T>> biFunction) {
        if (tArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", tArr.length, iArr.length));
        }

        final ClassLabels codec = ClassLabels.fit(iArr);
        final int k = codec.k;
        if (k <= 2) {
            throw new IllegalArgumentException(String.format("Only %d classes", k));
        }

        final int pos = i;
        final int neg = i2;
        final int[] classSizes = codec.ni;
        final int[] encoded = codec.y;

        // Generic array creation is unavoidable here; rows are filled below.
        final Classifier<T>[][] classifiers = new Classifier[k][];
        PlattScaling[][] platts = null;

        for (int hi = 1; hi < k; hi++) {
            classifiers[hi] = new Classifier[hi];
            for (int lo = 0; lo < hi; lo++) {
                final int n = classSizes[lo] + classSizes[hi];
                // Allocate with the runtime component type so the trainer gets a real T[].
                final T[] subset = (T[]) Array.newInstance(tArr.getClass().getComponentType(), n);
                final int[] subsetLabels = new int[n];

                int count = 0;
                for (int s = 0; s < encoded.length; s++) {
                    if (encoded[s] == hi) {
                        subset[count] = tArr[s];
                        subsetLabels[count] = pos;
                        count++;
                    } else if (encoded[s] == lo) {
                        subset[count] = tArr[s];
                        subsetLabels[count] = neg;
                        count++;
                    }
                }

                classifiers[hi][lo] = biFunction.apply(subset, subsetLabels);

                // Probe once, on the very first pair, whether scores are available.
                if (lo == 0 && hi == 1) {
                    try {
                        classifiers[hi][lo].score(subset[0]);
                        platts = new PlattScaling[k][];
                    } catch (UnsupportedOperationException ignored) {
                        // Deliberately not rethrown: the model simply has no soft prediction.
                        logger.info("The classifier doesn't support score function. Don't fit Platt scaling.");
                    }
                }

                if (platts != null) {
                    if (platts[hi] == null) {
                        platts[hi] = new PlattScaling[hi];
                    }
                    platts[hi][lo] = PlattScaling.fit(classifiers[hi][lo], subset, subsetLabels);
                }
            }
        }

        return new OneVersusOne<>(classifiers, platts);
    }

    /**
     * Fits one binary classifier per pair of classes using {@code +1} and
     * {@code -1} as the positive and negative labels.
     *
     * @param tArr       training samples
     * @param iArr       training labels, same length as {@code tArr}
     * @param biFunction trains a binary classifier from samples and labels
     * @param <T>        the sample type
     * @return the fitted model
     */
    public static <T> OneVersusOne<T> fit(T[] tArr, int[] iArr, BiFunction<T[], int[], Classifier<T>> biFunction) {
        return fit(tArr, iArr, 1, -1, biFunction);
    }

    /**
     * Array factory originally emitted as a lowered lambda body. Retained
     * unchanged for compatibility; not used by this class.
     *
     * @param i the array length
     * @return a new {@code Tuple} array of that length
     */
    public static /* synthetic */ Tuple[] lambda$fit$0(int i) {
        return new Tuple[i];
    }

    /**
     * Trainer adapter originally emitted as a lowered lambda body. Retained
     * unchanged for compatibility; not used by this class. The {@code iArr}
     * argument is ignored.
     *
     * @param biFunction the data-frame trainer
     * @param formula    the model formula
     * @param tupleArr   the rows to train on
     * @param iArr       unused
     * @return the classifier produced by {@code biFunction}
     */
    public static /* synthetic */ Classifier lambda$fit$1(BiFunction biFunction, Formula formula, Tuple[] tupleArr, int[] iArr) {
        return (Classifier) biFunction.apply(formula, DataFrame.of((List<? extends Tuple>) Arrays.asList(tupleArr)));
    }

    /**
     * Predicts the class of a sample by majority vote over all pairwise
     * classifiers. A pairwise result greater than zero is a vote for the
     * higher-indexed class of the pair, anything else a vote for the
     * lower-indexed one.
     *
     * @param t the sample
     * @return the label of the class with the most votes
     */
    @Override
    public int predict(T t) {
        final int[] votes = new int[k];
        for (int hi = 1; hi < k; hi++) {
            for (int lo = 0; lo < hi; lo++) {
                if (classifiers[hi][lo].predict(t) > 0) {
                    votes[hi]++;
                } else {
                    votes[lo]++;
                }
            }
        }
        return labels.valueOf(MathEx.whichMax(votes));
    }

    /**
     * Predicts the class of a sample and fills in the estimated class
     * distribution.
     *
     * @param t    the sample
     * @param dArr output; entries {@code 0..k-1} are overwritten with the
     *             coupled class probabilities, and the arg-max is taken over
     *             the whole array
     * @return the label of the class with the largest entry in {@code dArr}
     * @throws UnsupportedOperationException if the model has no Platt scalers
     */
    @Override
    public int predict(T t, double[] dArr) {
        if (platts == null) {
            throw new UnsupportedOperationException("Platt scaling is not available");
        }

        final double[][] r = new double[k][k];
        for (int hi = 1; hi < k; hi++) {
            for (int lo = 0; lo < hi; lo++) {
                r[hi][lo] = platts[hi][lo].scale(classifiers[hi][lo].score(t));
                r[lo][hi] = 1.0d - r[hi][lo];
            }
        }

        coupling(r, dArr);
        return labels.valueOf(MathEx.whichMax(dArr));
    }
}
