package smile.regression;

import com.github.mikephil.charting.utils.Utils;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Properties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.blas.Transpose;
import smile.math.matrix.BiconjugateGradient;
import smile.math.matrix.DMatrix;
import smile.math.matrix.Matrix;
import smile.math.matrix.Preconditioner;

/* loaded from: classes6.dex */
/**
 * Lasso (L1-regularized least squares) regression.
 *
 * <p>{@link #train} minimizes the primal objective
 * {@code ||X w - (y - mean(y))||^2 + lambda * ||w||_1}. It uses an interior-point method:
 * <ul>
 *   <li>Bound variables {@code u} are introduced with the constraints {@code -u < w < u}.</li>
 *   <li>A log-barrier with parameter {@code t} enforces these constraints.</li>
 *   <li>Newton steps are computed with {@link BiconjugateGradient} preconditioned by {@link PCGMatrix}.</li>
 *   <li>A backtracking line search follows each Newton step.</li>
 * </ul>
 * Iteration stops when the relative duality gap {@code (pobj - dobj) / dobj} drops below the
 * tolerance.
 */
public class LASSO {
    private static final Logger logger = LoggerFactory.getLogger(LASSO.class);

    /** Factor by which the barrier parameter {@code t} may grow per outer iteration. */
    private static final int MU = 2;
    /** Fraction of the predicted decrease the line search must achieve. */
    private static final double ALPHA = 0.01;
    /** Step-size shrink factor used by the backtracking line search. */
    private static final double BETA = 0.5;
    /** Maximum number of backtracking steps per Newton iteration. */
    private static final int MAX_LS_ITER = 100;
    /** Maximum number of BiCG iterations per Newton step. */
    private static final int PCG_MAX_ITER = 5000;
    /** Scales the duality gap into the relative tolerance handed to BiCG. */
    private static final double PCG_ETA = 1E-3;
    /** Below this many columns, {@code A'A} is precomputed once by {@link PCGMatrix}. */
    private static final int DENSE_GRAM_MAX_COLS = 10000;

    /**
     * Linear operator and preconditioner for the Newton system of the barrier problem.
     *
     * <p>With {@code D1 = diag(d1)} and {@code D2 = diag(d2)}, {@link #mv(double[], double[])}
     * applies the symmetric {@code 2p x 2p} matrix {@code [[2 A'A + D1, D2], [D2, D1]]}.
     * {@link #solve} applies the inverse of the block-diagonal preconditioner
     * {@code [[prb, D2], [D2, D1]]}, which relies on {@code prs[i] = prb[i] * d1[i] - d2[i]^2}.
     *
     * <p>The arrays {@code d1, d2, prb, prs} are shared with {@link LASSO#train}. They are
     * updated in place there before every solve, so this object always sees the current values.
     */
    public static class PCGMatrix extends DMatrix implements Preconditioner {
        /** The design matrix. */
        Matrix A;
        /** Cached {@code A'A}; {@code null} when {@code A} has too many columns. */
        Matrix AtA;
        /** Scratch buffer of length {@code p} holding {@code A'A x}. */
        double[] atax;
        /** Scratch buffer of length {@code nrows(A)} holding {@code A x}. */
        double[] ax;
        /** Diagonal of the Hessian blocks (shared with the caller). */
        double[] d1;
        /** Diagonal of the off-diagonal Hessian blocks (shared with the caller). */
        double[] d2;
        /** Number of columns of {@code A}; the operator is {@code 2p x 2p}. */
        int p;
        /** Preconditioner diagonal of the upper-left block (shared with the caller). */
        double[] prb;
        /** Per-coordinate determinant of the 2x2 preconditioner blocks (shared with the caller). */
        double[] prs;

        PCGMatrix(Matrix A, double[] d1, double[] d2, double[] prb, double[] prs) {
            this.A = A;
            this.d1 = d1;
            this.d2 = d2;
            this.prb = prb;
            this.prs = prs;
            this.p = A.ncols();
            this.ax = new double[A.nrows()];
            this.atax = new double[p];
            // For moderately wide problems, form the Gram matrix once. Then each product is a
            // single p x p multiply instead of a product with A followed by one with A'.
            if (p < DENSE_GRAM_MAX_COLS) {
                this.AtA = A.ata();
            }
        }

        @Override
        public double get(int i, int j) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void mv(Transpose trans, double alpha, double[] x, double beta, double[] y) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void mv(double[] work, int inputOffset, int outputOffset) {
            throw new UnsupportedOperationException();
        }

        /**
         * Computes {@code y = H x} for {@code H = [[2 A'A + D1, D2], [D2, D1]]}.
         *
         * @param x input vector of length {@code 2p}.
         * @param y output vector of length {@code 2p}.
         */
        @Override
        public void mv(double[] x, double[] y) {
            if (AtA != null) {
                AtA.mv(x, atax);
            } else {
                A.mv(x, ax);
                A.tv(ax, atax);
            }

            for (int i = 0; i < p; i++) {
                y[i] = 2.0 * atax[i] + d1[i] * x[i] + d2[i] * x[p + i];
                y[p + i] = d2[i] * x[i] + d1[i] * x[p + i];
            }
        }

        @Override
        public int ncols() {
            return 2 * p;
        }

        @Override
        public int nrows() {
            return 2 * p;
        }

        @Override
        public DMatrix set(int i, int j, double x) {
            throw new UnsupportedOperationException();
        }

        @Override
        public long size() {
            return A.size();
        }

        /**
         * Applies the inverse of the 2x2-block diagonal preconditioner
         * {@code [[prb, D2], [D2, D1]]} to {@code b}.
         *
         * @param b right-hand side of length {@code 2p}.
         * @param x output of length {@code 2p}.
         */
        @Override
        public void solve(double[] b, double[] x) {
            for (int i = 0; i < p; i++) {
                x[i] = (d1[i] * b[i] - d2[i] * b[p + i]) / prs[i];
                x[p + i] = (-d2[i] * b[i] + prb[i] * b[p + i]) / prs[i];
            }
        }

        @Override
        public void tv(double[] work, int inputOffset, int outputOffset) {
            throw new UnsupportedOperationException();
        }

        /** The operator is symmetric, so {@code H' x == H x}. */
        @Override
        public void tv(double[] x, double[] y) {
            mv(x, y);
        }
    }

    /**
     * Fits a lasso model with the default hyperparameters.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @return the fitted model.
     */
    public static LinearModel fit(Formula formula, DataFrame data) {
        return fit(formula, data, new Properties());
    }

    /**
     * Fits a lasso model with tolerance {@code 1E-4} and at most 1000 iterations.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param lambda the shrinkage/regularization parameter.
     * @return the fitted model.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda) {
        return fit(formula, data, lambda, 1E-4, 1000);
    }

    /**
     * Fits a lasso model.
     *
     * <p>The design matrix is standardized column-wise before training: each column has its
     * mean subtracted and is divided by its standard deviation. The coefficients are then mapped
     * back to the original scale. The intercept is {@code mean(y) - w . colMeans}.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param lambda the shrinkage/regularization parameter; must be non-negative.
     * @param tol the relative duality-gap tolerance; must be positive.
     * @param maxIter the maximum number of Newton iterations; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if a column is constant or a parameter is invalid.
     */
    public static LinearModel fit(Formula formula, DataFrame data, double lambda, double tol, int maxIter) {
        Formula expanded = formula.expand(data.schema());
        StructType schema = expanded.bind(data.schema());
        Matrix X = expanded.matrix(data, false);
        double[] y = expanded.y(data).toDoubleArray();

        double[] center = X.colMeans();
        double[] scale = X.colSds();
        for (int j = 0; j < scale.length; j++) {
            if (MathEx.isZero(scale[j])) {
                throw new IllegalArgumentException(String.format("The column '%s' is constant", X.colName(j)));
            }
        }

        double[] w = train(X.scale(center, scale), y, lambda, tol, maxIter);
        // Undo the column scaling so w applies to the raw (centered) features.
        for (int j = 0; j < w.length; j++) {
            w[j] /= scale[j];
        }

        double intercept = MathEx.mean(y) - MathEx.dot(w, center);
        return new LinearModel(expanded, schema, X, y, w, intercept);
    }

    /**
     * Fits a lasso model with hyperparameters read from properties.
     *
     * <p>The recognized keys and their defaults are:
     * <ul>
     *   <li>{@code smile.lasso.lambda} (default "1")</li>
     *   <li>{@code smile.lasso.tolerance} (default "1E-4")</li>
     *   <li>{@code smile.lasso.max.iterations} (default "1000")</li>
     * </ul>
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param params the hyperparameters.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static LinearModel fit(Formula formula, DataFrame data, Properties params) {
        double lambda = Double.parseDouble(params.getProperty("smile.lasso.lambda", "1"));
        double tol = Double.parseDouble(params.getProperty("smile.lasso.tolerance", "1E-4"));
        int maxIter = Integer.parseInt(params.getProperty("smile.lasso.max.iterations", "1000"));
        return fit(formula, data, lambda, tol, maxIter);
    }

    /**
     * Returns {@code sum(log(-f[i][j]))} over all entries, i.e. the log-barrier term.
     * Every entry is expected to be strictly negative; otherwise the result is NaN or -infinity.
     */
    private static double sumlogneg(double[][] f) {
        double sum = 0.0;
        for (double[] row : f) {
            for (double v : row) {
                sum += Math.log(-v);
            }
        }
        return sum;
    }

    /**
     * Solves the lasso problem {@code min ||x w - (y - mean(y))||^2 + lambda * ||w||_1}.
     * No intercept is fitted here; {@link #fit(Formula, DataFrame, double, double, int)} derives
     * it afterwards.
     *
     * <p>Each outer iteration first evaluates the duality gap of the current iterate. Training
     * stops as soon as the relative gap is below {@code tol}. Otherwise, if {@code maxIter}
     * Newton steps have already been taken, an error is logged and the current coefficients are
     * returned. Training also stops, with an error log, if the line search fails.
     *
     * @param x the n x p design matrix. NOTE(review): the preconditioner assumes
     *          {@code diag(x'x) == 1}. This presumably only affects BiCG convergence speed.
     * @param y the response vector of length n.
     * @param lambda the shrinkage/regularization parameter; must be non-negative.
     * @param tol the relative duality-gap tolerance; must be positive.
     * @param maxIter the maximum number of Newton iterations; must be positive.
     * @return the coefficient vector of length p.
     * @throws IllegalArgumentException if a parameter is invalid.
     */
    public static double[] train(Matrix x, double[] y, double lambda, double tol, int maxIter) {
        // Comparisons against 0.0 (not a library epsilon), and NaN is rejected explicitly.
        if (lambda < 0.0 || Double.isNaN(lambda)) {
            throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
        }
        if (tol <= 0.0 || Double.isNaN(tol)) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        int n = x.nrows();
        int p = x.ncols();

        // Centered response.
        double[] yc = new double[n];
        double ym = MathEx.mean(y);
        for (int k = 0; k < n; k++) {
            yc[k] = y[k] - ym;
        }

        double t = Math.min(Math.max(1.0, 1.0 / lambda), 2 * p / 1e-3); // barrier parameter
        double dobj = Double.NEGATIVE_INFINITY;                        // best dual objective so far
        double s = Double.POSITIVE_INFINITY;                           // last line-search step

        // Primal iterate w, bounds u (start at w = 0, u = 1), and constraint values f < 0.
        double[] w = new double[p];
        double[] u = new double[p];
        Arrays.fill(u, 1.0);
        double[][] f = new double[2][p];
        for (int j = 0; j < p; j++) {
            f[0][j] = w[j] - u[j];
            f[1][j] = -w[j] - u[j];
        }

        double[] z = new double[n];            // residual x w - yc
        double[] newW = new double[p];
        double[] newU = new double[p];
        double[] newZ = new double[n];
        double[][] newF = new double[2][p];

        double[] dx = new double[p];
        double[] du = new double[p];
        double[] dxu = new double[2 * p];      // Newton direction [dx; du]
        double[] grad = new double[2 * p];     // right-hand side of the Newton system

        // Approximation of diag(2 x'x) used by the preconditioner (see NOTE on parameter x).
        double[] diagxtx = new double[p];
        Arrays.fill(diagxtx, 2.0);

        double[] nu = new double[n];           // dual point
        double[] xnu = new double[p];

        double[] q1 = new double[p];
        double[] q2 = new double[p];
        double[] d1 = new double[p];
        double[] d2 = new double[p];
        double[][] gradphi = new double[2][p];
        double[] prb = new double[p];
        double[] prs = new double[p];

        // Shares d1, d2, prb, prs with this method; they are refreshed before every solve.
        PCGMatrix pcg = new PCGMatrix(x, d1, d2, prb, prs);

        // Once BiCG misses its tolerance, stop tightening the tolerance on later iterations.
        boolean pcgMissedTolerance = false;

        int iter = 0;
        while (true) {
            // Residual and dual-feasible point: nu = 2 z, scaled so ||x' nu||_inf <= lambda.
            x.mv(w, z);
            for (int k = 0; k < n; k++) {
                z[k] -= yc[k];
                nu[k] = 2.0 * z[k];
            }

            x.tv(nu, xnu);
            double maxXnu = MathEx.normInf(xnu);
            if (maxXnu > lambda) {
                double shrink = lambda / maxXnu;
                for (int k = 0; k < n; k++) {
                    nu[k] *= shrink;
                }
            }

            double pobj = MathEx.dot(z, z) + lambda * MathEx.norm1(w);
            dobj = Math.max(-0.25 * MathEx.dot(nu, nu) - MathEx.dot(nu, yc), dobj);
            if (iter % 10 == 0) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", iter, pobj, dobj));
            }

            double gap = pobj - dobj;
            if (gap / dobj < tol) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", iter, pobj, dobj));
                break;
            }

            // Check the budget only after testing convergence of the current iterate, so this
            // error is logged exactly when maxIter Newton steps did not reach the tolerance.
            if (iter >= maxIter) {
                logger.error("LASSO: Too many iterations.");
                break;
            }

            // Increase the barrier parameter after a reasonably long previous step.
            if (s >= 0.5) {
                t = Math.max(Math.min(2 * p * MU / gap, MU * t), t);
            }

            // Hessian diagonals of the barrier term.
            for (int j = 0; j < p; j++) {
                double q1j = 1.0 / (u[j] + w[j]);
                double q2j = 1.0 / (u[j] - w[j]);
                q1[j] = q1j;
                q2[j] = q2j;
                d1[j] = (q1j * q1j + q2j * q2j) / t;
                d2[j] = (q1j * q1j - q2j * q2j) / t;
            }

            // Gradient of the barrier objective with respect to (w, u).
            x.tv(z, gradphi[0]);
            for (int j = 0; j < p; j++) {
                gradphi[0][j] = 2.0 * gradphi[0][j] - (q1[j] - q2[j]) / t;
                gradphi[1][j] = lambda - (q1[j] + q2[j]) / t;
                grad[j] = -gradphi[0][j];
                grad[j + p] = -gradphi[1][j];
            }

            // Preconditioner blocks consumed by PCGMatrix.solve.
            for (int j = 0; j < p; j++) {
                prb[j] = diagxtx[j] + d1[j];
                prs[j] = prb[j] * d1[j] - d2[j] * d2[j];
            }

            // Relative BiCG tolerance that tightens with the duality gap.
            double normg = MathEx.norm(grad);
            double pcgtol = Math.min(0.1, PCG_ETA * gap / Math.min(1.0, normg));
            if (iter != 0 && !pcgMissedTolerance) {
                pcgtol *= 0.1;
            }

            double error = BiconjugateGradient.solve(pcg, grad, dxu, pcg, pcgtol, 1, PCG_MAX_ITER);
            if (error > pcgtol) {
                pcgMissedTolerance = true;
            }

            System.arraycopy(dxu, 0, dx, 0, p);
            System.arraycopy(dxu, p, du, 0, p);

            // Backtracking line search: keep the iterate strictly feasible (f < 0) and require
            // sufficient decrease of the barrier objective phi.
            double phi = MathEx.dot(z, z) + lambda * MathEx.sum(u) - sumlogneg(f) / t;
            s = 1.0;
            double gdx = MathEx.dot(grad, dxu);

            int lsiter = 0;
            for (; lsiter < MAX_LS_ITER; lsiter++) {
                for (int j = 0; j < p; j++) {
                    newW[j] = w[j] + s * dx[j];
                    newU[j] = u[j] + s * du[j];
                    newF[0][j] = newW[j] - newU[j];
                    newF[1][j] = -newW[j] - newU[j];
                }

                if (MathEx.max(newF) < 0.0) {
                    x.mv(newW, newZ);
                    for (int k = 0; k < n; k++) {
                        newZ[k] -= yc[k];
                    }

                    double newPhi = MathEx.dot(newZ, newZ) + lambda * MathEx.sum(newU) - sumlogneg(newF) / t;
                    if (newPhi - phi <= ALPHA * s * gdx) {
                        break;
                    }
                }
                s *= BETA;
            }

            if (lsiter == MAX_LS_ITER) {
                logger.error("LASSO: Too many iterations of line search.");
                break;
            }

            System.arraycopy(newW, 0, w, 0, p);
            System.arraycopy(newU, 0, u, 0, p);
            System.arraycopy(newF[0], 0, f[0], 0, p);
            System.arraycopy(newF[1], 0, f[1], 0, p);
            iter++;
        }

        return w;
    }
}
