package smile.projection;

import java.io.Serializable;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/* loaded from: classes6.dex */
/**
 * Generalized Hebbian Algorithm (GHA, also known as Sanger's rule): an incremental estimator of
 * a {@code p x n} linear projection that is trained one sample at a time via
 * {@link #update(double[])}.
 *
 * <p>For each sample {@code x}, with {@code y = W x} computed once from the current weights, every
 * weight is moved by {@code r * y[j] * (x[i] - sum_{l <= j} W[l][i] * y[l])}. Rows are processed
 * in order {@code j = 0..p-1} and updated in place, so rows {@code l < j} already hold their new
 * values when row {@code j} is processed.
 *
 * <p>This class is not thread-safe: {@link #update(double[])} mutates the projection matrix and
 * reuses internal scratch buffers.
 */
public class GHA implements LinearProjection, Serializable {
    private static final long serialVersionUID = 2L;

    /** Dimension of the input space (number of columns of the projection). */
    private final int n;
    /** Dimension of the feature space (number of rows of the projection). */
    private final int p;
    /** The {@code p x n} projection matrix, updated in place by {@link #update(double[])}. */
    private final Matrix projection;
    /** Learning rate applied to every weight update. */
    private double r;
    /** Scratch buffer of length {@code n}, receives the reconstruction of the input. */
    private final double[] wy;
    /** Scratch buffer of length {@code p}, receives the projected input {@code W x}. */
    private final double[] y;

    /**
     * Constructs a GHA model whose weights are initialized with small random values.
     *
     * @param n the dimension of the input space; must be at least 2.
     * @param p the dimension of the feature space; must satisfy {@code 1 <= p <= n}.
     * @param r the learning rate.
     * @throws IllegalArgumentException if {@code n} or {@code p} is out of range.
     */
    public GHA(int n, int p, double r) {
        validateDimensions(n, p);
        this.n = n;
        this.p = p;
        this.r = r;
        this.y = new double[p];
        this.wy = new double[n];
        this.projection = new Matrix(p, n);
        // Scale the random draws by 0.1 so the initial weights start small; the row-major fill
        // order is kept so the random sequence consumed is unchanged.
        for (int row = 0; row < p; row++) {
            for (int col = 0; col < n; col++) {
                this.projection.set(row, col, MathEx.random() * 0.1d);
            }
        }
    }

    /**
     * Constructs a GHA model starting from the given weights.
     *
     * @param w the initial projection, one row per feature; must be non-empty and rectangular,
     *          with at least 2 columns and no more rows than columns.
     * @param r the learning rate.
     * @throws IllegalArgumentException if {@code w} is empty, ragged, or has invalid dimensions.
     */
    public GHA(double[][] w, double r) {
        if (w.length == 0) {
            throw new IllegalArgumentException("Empty projection matrix");
        }
        int cols = w[0].length;
        for (int row = 1; row < w.length; row++) {
            if (w[row].length != cols) {
                throw new IllegalArgumentException(String.format(
                        "Projection matrix row %d has %d columns, expected: %d", row, w[row].length, cols));
            }
        }
        validateDimensions(cols, w.length);
        this.p = w.length;
        this.n = cols;
        this.r = r;
        this.y = new double[this.p];
        this.wy = new double[this.n];
        this.projection = new Matrix(w);
    }

    /**
     * Checks the input/feature space dimensions shared by both constructors.
     *
     * @throws IllegalArgumentException if {@code n < 2} or {@code p} is not in {@code [1, n]}.
     */
    private static void validateDimensions(int n, int p) {
        if (n < 2) {
            throw new IllegalArgumentException("Invalid dimension of input space: " + n);
        }
        if (p < 1 || p > n) {
            throw new IllegalArgumentException("Invalid dimension of feature space: " + p);
        }
    }

    /**
     * Returns the current learning rate.
     *
     * @return the learning rate.
     */
    public double getLearningRate() {
        return this.r;
    }

    /**
     * Returns the projection matrix. This is the live internal matrix, not a copy: later calls
     * to {@link #update(double[])} modify it.
     *
     * @return the {@code p x n} projection matrix.
     */
    @Override // smile.projection.LinearProjection
    public Matrix getProjection() {
        return this.projection;
    }

    /**
     * Sets the learning rate used by subsequent updates.
     *
     * @param r the new learning rate.
     * @return this object, for call chaining.
     */
    public GHA setLearningRate(double r) {
        this.r = r;
        return this;
    }

    /**
     * Updates the projection with one training sample.
     *
     * @param x the input sample, of length {@code n}.
     * @return the squared distance between {@code x} and its reconstruction from the updated
     *         projection (projection followed by transposed projection).
     * @throws IllegalArgumentException if {@code x.length != n}.
     * @throws IllegalStateException if any weight becomes infinite or NaN.
     */
    public double update(double[] x) {
        if (x.length != this.n) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", Integer.valueOf(x.length), Integer.valueOf(this.n)));
        }
        // y = W x, computed once from the weights as they were before this update.
        this.projection.mv(x, this.y);
        for (int j = 0; j < this.p; j++) {
            // Invariant for the whole row: y[j] is not written inside the loop.
            double rate = this.r * this.y[j];
            for (int i = 0; i < this.n; i++) {
                // Residual of x[i] after subtracting the contributions of features 0..j.
                double delta = x[i];
                for (int l = 0; l <= j; l++) {
                    delta -= this.projection.get(l, i) * this.y[l];
                }
                this.projection.add(j, i, rate * delta);
                // isFinite rather than isInfinite: an overflowed y or a NaN input yields NaN
                // weights (inf - inf, 0 * inf), which would otherwise corrupt the model silently.
                if (!Double.isFinite(this.projection.get(j, i))) {
                    throw new IllegalStateException("GHA lost convergence. Lower learning rate?");
                }
            }
        }
        // Reconstruct x with the updated weights: wy = W' (W x).
        this.projection.mv(x, this.y);
        this.projection.tv(this.y, this.wy);
        return MathEx.squaredDistance(x, this.wy);
    }
}
