/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import org.apache.spark.ml.feature.Instance;
import org.apache.spark.mllib.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.util.MLUtils$;
import scala.Array$;
import scala.Function0;
import scala.Function2;
import scala.MatchError;
import scala.NotImplementedError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.StringBuilder;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;

@ScalaSignature(bytes="\u0006\u0001a4A!\u0001\u0002\u0005\u001b\t\u0011Bj\\4jgRL7-Q4he\u0016<\u0017\r^8s\u0015\t\u0019A!\u0001\bdY\u0006\u001c8/\u001b4jG\u0006$\u0018n\u001c8\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0003ta\u0006\u00148N\u0003\u0002\n\u0015\u00051\u0011\r]1dQ\u0016T\u0011aC\u0001\u0004_J<7\u0001A\n\u0004\u00019!\u0002CA\b\u0013\u001b\u0005\u0001\"\"A\t\u0002\u000bM\u001c\u0017\r\\1\n\u0005M\u0001\"AB!osJ+g\r\u0005\u0002\u0010+%\u0011a\u0003\u0005\u0002\r'\u0016\u0014\u0018.\u00197ju\u0006\u0014G.\u001a\u0005\t1\u0001\u0011\t\u0011)A\u00053\u0005a1m\\3gM&\u001c\u0017.\u001a8ugB\u0011!dH\u0007\u00027)\u0011A$H\u0001\u0007Y&t\u0017\r\\4\u000b\u0005y1\u0011!B7mY&\u0014\u0017B\u0001\u0011\u001c\u0005\u00191Vm\u0019;pe\"A!\u0005\u0001B\u0001B\u0003%1%\u0001\u0006ok6\u001cE.Y:tKN\u0004\"a\u0004\u0013\n\u0005\u0015\u0002\"aA%oi\"Aq\u0005\u0001B\u0001B\u0003%\u0001&\u0001\u0007gSRLe\u000e^3sG\u0016\u0004H\u000f\u0005\u0002\u0010S%\u0011!\u0006\u0005\u0002\b\u0005>|G.Z1o\u0011!a\u0003A!A!\u0002\u0013i\u0013a\u00034fCR,(/Z:Ti\u0012\u00042a\u0004\u00181\u0013\ty\u0003CA\u0003BeJ\f\u0017\u0010\u0005\u0002\u0010c%\u0011!\u0007\u0005\u0002\u0007\t>,(\r\\3\t\u0011Q\u0002!\u0011!Q\u0001\n5\nABZ3biV\u0014Xm]'fC:DQA\u000e\u0001\u0005\u0002]\na\u0001P5oSRtDC\u0002\u001d;wqjd\b\u0005\u0002:\u00015\t!\u0001C\u0003\u0019k\u0001\u0007\u0011\u0004C\u0003#k\u0001\u00071\u0005C\u0003(k\u0001\u0007\u0001\u0006C\u0003-k\u0001\u0007Q\u0006C\u00035k\u0001\u0007Q\u0006C\u0004A\u0001\u0001\u0007I\u0011B!\u0002\u0013],\u0017n\u001a5u'VlW#\u0001\u0019\t\u000f\r\u0003\u0001\u0019!C\u0005\t\u0006iq/Z5hQR\u001cV/\\0%KF$\"!\u0012%\u0011\u0005=1\u0015BA$\u0011\u0005\u0011)f.\u001b;\t\u000f%\u0013\u0015\u0011!a\u0001a\u0005\u0019\u0001\u0010J\u0019\t\r-\u0003\u0001\u0015)\u00031\u0003)9X-[4iiN+X\u000e\t\u0005\b\u001b\u0002\u0001\r\u0011\"\u0003B\u0003\u001dawn]:Tk6Dqa\u0014\u0001A\u0002\u0013%\u0001+A\u0006m_N\u001c8+^7`I\u0015\fHCA#R\u0011\u001dIe*!AA\u0002ABaa\u0015\u0001!B\u0013\u
0001\u0014\u0001\u00037pgN\u001cV/\u001c\u0011\t\u000fU\u0003!\u0019!C\u0005-\u0006\t2m\\3gM&\u001c\u0017.\u001a8ug\u0006\u0013(/Y=\u0016\u00035Ba\u0001\u0017\u0001!\u0002\u0013i\u0013AE2pK\u001a4\u0017nY5f]R\u001c\u0018I\u001d:bs\u0002BqA\u0017\u0001C\u0002\u0013%1,A\u0002eS6,\u0012a\t\u0005\u0007;\u0002\u0001\u000b\u0011B\u0012\u0002\t\u0011LW\u000e\t\u0005\b?\u0002\u0011\r\u0011\"\u0003W\u0003A9'/\u00193jK:$8+^7BeJ\f\u0017\u0010\u0003\u0004b\u0001\u0001\u0006I!L\u0001\u0012OJ\fG-[3oiN+X.\u0011:sCf\u0004\u0003\"B2\u0001\t\u0003!\u0017aA1eIR\u0011QMZ\u0007\u0002\u0001!)qM\u0019a\u0001Q\u0006A\u0011N\\:uC:\u001cW\r\u0005\u0002jY6\t!N\u0003\u0002l\t\u00059a-Z1ukJ,\u0017BA7k\u0005!Ien\u001d;b]\u000e,\u0007\"B8\u0001\t\u0003\u0001\u0018!B7fe\u001e,GCA3r\u0011\u0015\u0011h\u000e1\u00019\u0003\u0015yG\u000f[3s\u0011\u0015!\b\u0001\"\u0001B\u0003\u0011awn]:\t\u000bY\u0004A\u0011A<\u0002\u0011\u001d\u0014\u0018\rZ5f]R,\u0012!\u0007")
public class LogisticAggregator
implements Serializable {
    private final int numClasses;
    private final boolean fitIntercept;
    public final double[] org$apache$spark$ml$classification$LogisticAggregator$$featuresStd;
    private double org$apache$spark$ml$classification$LogisticAggregator$$weightSum;
    private double lossSum;
    private final double[] coefficientsArray;
    private final int org$apache$spark$ml$classification$LogisticAggregator$$dim;
    private final double[] gradientSumArray;

    public double org$apache$spark$ml$classification$LogisticAggregator$$weightSum() {
        return this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum;
    }

    private void org$apache$spark$ml$classification$LogisticAggregator$$weightSum_$eq(double x$1) {
        this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum = x$1;
    }

    private double lossSum() {
        return this.lossSum;
    }

    private void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    private double[] coefficientsArray() {
        return this.coefficientsArray;
    }

    public int org$apache$spark$ml$classification$LogisticAggregator$$dim() {
        return this.org$apache$spark$ml$classification$LogisticAggregator$$dim;
    }

    private double[] gradientSumArray() {
        return this.gradientSumArray;
    }

    public LogisticAggregator add(Instance instance) {
        Instance instance2 = instance;
        if (instance2 != null) {
            double label = instance2.label();
            double weight = instance2.weight();
            Vector features = instance2.features();
            Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$dim() == features.size(), (Function0)new Serializable(this, features){
                public static final long serialVersionUID = 0L;
                private final /* synthetic */ LogisticAggregator $outer;
                private final Vector features$1;

                public final String apply() {
                    return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Dimensions mismatch when adding new instance."})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{" Expecting ", " but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$dim()), BoxesRunTime.boxToInteger((int)this.features$1.size())}))).toString();
                }
                {
                    if ($outer == null) {
                        throw null;
                    }
                    this.$outer = $outer;
                    this.features$1 = features$1;
                }
            });
            Predef$.MODULE$.require(weight >= 0.0, (Function0)new Serializable(this, weight){
                public static final long serialVersionUID = 0L;
                private final double weight$2;

                public final String apply() {
                    return new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"instance weight, ", " has to be >= 0.0"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.weight$2)}));
                }
                {
                    this.weight$2 = weight$2;
                }
            });
            if (weight == 0.0) {
                return this;
            }
            double[] localCoefficientsArray = this.coefficientsArray();
            double[] localGradientSumArray = this.gradientSumArray();
            int n = this.numClasses;
            switch (n) {
                default: {
                    NotImplementedError notImplementedError = new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports binary classification for now.");
                    break;
                }
                case 2: {
                    DoubleRef sum = DoubleRef.create((double)0.0);
                    features.foreachActive((Function2<Object, Object, BoxedUnit>)new Serializable(this, localCoefficientsArray, sum){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ LogisticAggregator $outer;
                        private final double[] localCoefficientsArray$1;
                        private final DoubleRef sum$1;

                        public final void apply(int index2, double value) {
                            this.apply$mcVID$sp(index2, value);
                        }

                        public void apply$mcVID$sp(int index2, double value) {
                            if (this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index2] != 0.0 && value != 0.0) {
                                this.sum$1.elem += this.localCoefficientsArray$1[index2] * (value / this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index2]);
                            }
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                            this.localCoefficientsArray$1 = localCoefficientsArray$1;
                            this.sum$1 = sum$1;
                        }
                    });
                    double margin = -(sum.elem + (this.fitIntercept ? localCoefficientsArray[this.org$apache$spark$ml$classification$LogisticAggregator$$dim()] : 0.0));
                    double multiplier = weight * (1.0 / (1.0 + package$.MODULE$.exp(margin)) - label);
                    features.foreachActive((Function2<Object, Object, BoxedUnit>)new Serializable(this, localGradientSumArray, multiplier){
                        public static final long serialVersionUID = 0L;
                        private final /* synthetic */ LogisticAggregator $outer;
                        private final double[] localGradientSumArray$1;
                        private final double multiplier$1;

                        public final void apply(int index2, double value) {
                            this.apply$mcVID$sp(index2, value);
                        }

                        public void apply$mcVID$sp(int index2, double value) {
                            if (this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index2] != 0.0 && value != 0.0) {
                                this.localGradientSumArray$1[index2] = this.localGradientSumArray$1[index2] + this.multiplier$1 * (value / this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd[index2]);
                            }
                        }
                        {
                            if ($outer == null) {
                                throw null;
                            }
                            this.$outer = $outer;
                            this.localGradientSumArray$1 = localGradientSumArray$1;
                            this.multiplier$1 = multiplier$1;
                        }
                    });
                    if (this.fitIntercept) {
                        localGradientSumArray[this.org$apache$spark$ml$classification$LogisticAggregator$$dim()] = localGradientSumArray[this.org$apache$spark$ml$classification$LogisticAggregator$$dim()] + multiplier;
                    }
                    if (label > 0.0) {
                        this.lossSum_$eq(this.lossSum() + weight * MLUtils$.MODULE$.log1pExp(margin));
                    } else {
                        this.lossSum_$eq(this.lossSum() + weight * (MLUtils$.MODULE$.log1pExp(margin) - margin));
                    }
                    NotImplementedError notImplementedError = BoxedUnit.UNIT;
                }
            }
            this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum_$eq(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() + weight);
            LogisticAggregator logisticAggregator = this;
            return logisticAggregator;
        }
        throw new MatchError((Object)instance2);
    }

    public LogisticAggregator merge(LogisticAggregator other) {
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$dim() == other.org$apache$spark$ml$classification$LogisticAggregator$$dim(), (Function0)new Serializable(this, other){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;
            private final LogisticAggregator other$1;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Dimensions mismatch when merging with another "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"LeastSquaresAggregator. Expecting ", " but got ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$dim()), BoxesRunTime.boxToInteger((int)this.other$1.org$apache$spark$ml$classification$LogisticAggregator$$dim())}))).toString();
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
                this.other$1 = other$1;
            }
        });
        if (other.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() != 0.0) {
            this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum_$eq(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() + other.org$apache$spark$ml$classification$LogisticAggregator$$weightSum());
            this.lossSum_$eq(this.lossSum() + other.lossSum());
            double[] localThisGradientSumArray = this.gradientSumArray();
            double[] localOtherGradientSumArray = other.gradientSumArray();
            int len = localThisGradientSumArray.length;
            for (int i = 0; i < len; ++i) {
                int n = i;
                localThisGradientSumArray[n] = localThisGradientSumArray[n] + localOtherGradientSumArray[i];
            }
        }
        return this;
    }

    public double loss() {
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() > 0.0, (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"The effective number of instances should be "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"greater than 0.0, but ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$weightSum())}))).toString();
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
        return this.lossSum() / this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum();
    }

    /*
     * WARNING - void declaration
     */
    public Vector gradient() {
        void var1_1;
        Predef$.MODULE$.require(this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum() > 0.0, (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ LogisticAggregator $outer;

            public final String apply() {
                return new StringBuilder().append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"The effective number of instances should be "})).s((Seq)Nil$.MODULE$)).append((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"greater than 0.0, but ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.$outer.org$apache$spark$ml$classification$LogisticAggregator$$weightSum())}))).toString();
            }
            {
                if ($outer == null) {
                    throw null;
                }
                this.$outer = $outer;
            }
        });
        Vector result = Vectors$.MODULE$.dense((double[])this.gradientSumArray().clone());
        BLAS$.MODULE$.scal(1.0 / this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum(), result);
        return var1_1;
    }

    public LogisticAggregator(Vector coefficients, int numClasses, boolean fitIntercept, double[] featuresStd, double[] featuresMean) {
        this.numClasses = numClasses;
        this.fitIntercept = fitIntercept;
        this.org$apache$spark$ml$classification$LogisticAggregator$$featuresStd = featuresStd;
        this.org$apache$spark$ml$classification$LogisticAggregator$$weightSum = 0.0;
        this.lossSum = 0.0;
        Vector vector = coefficients;
        if (vector instanceof DenseVector) {
            DenseVector denseVector = (DenseVector)vector;
            double[] dArray = denseVector.values();
            this.coefficientsArray = dArray;
            this.org$apache$spark$ml$classification$LogisticAggregator$$dim = fitIntercept ? this.coefficientsArray().length - 1 : this.coefficientsArray().length;
            this.gradientSumArray = (double[])Array$.MODULE$.ofDim(this.coefficientsArray().length, ClassTag$.MODULE$.Double());
            return;
        }
        throw new IllegalArgumentException(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"coefficients only supports dense vector but got type ", "."})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{coefficients.getClass()})));
    }
}

