/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.arbiter.scoring.impl;

import lombok.NonNull;
import org.deeplearning4j.arbiter.scoring.impl.BaseNetScoreFunction;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Score function that evaluates a network with a ROC-family evaluation and returns either the
 * area under the ROC curve ({@link Metric#AUC}) or the area under the precision-recall curve
 * ({@link Metric#AUPRC}).
 *
 * <p>{@link ROCType} selects the evaluation that is run ({@link ROC}, an averaged
 * {@link ROCBinary}, or an averaged {@link ROCMultiClass}); {@link Metric} selects which area is
 * reported. Both must be non-null when a score is computed.
 */
public class ROCScoreFunction extends BaseNetScoreFunction {
    /** Which ROC evaluation to run. */
    protected ROCType type;
    /** Which area-under-curve value to report from the evaluation. */
    protected Metric metric;

    /**
     * Creates a score function for the given evaluation type and metric.
     *
     * @param type   the ROC evaluation to run; must not be null
     * @param metric the area-under-curve value to report; must not be null
     * @throws NullPointerException if {@code type} or {@code metric} is null
     */
    public ROCScoreFunction(@NonNull ROCType type, @NonNull Metric metric) {
        if (type == null) {
            throw new NullPointerException("type is marked @NonNull but is null");
        }
        if (metric == null) {
            throw new NullPointerException("metric is marked @NonNull but is null");
        }
        this.type = type;
        this.metric = metric;
    }

    /**
     * No-arg constructor that leaves {@code type} and {@code metric} null; both must be set
     * (for example through the setters) before scoring.
     * NOTE(review): presumably exists for serialization frameworks - confirm.
     */
    protected ROCScoreFunction() {
    }

    @Override
    public String toString() {
        return "ROCScoreFunction(type=" + this.type + ",metric=" + this.metric + ")";
    }

    /**
     * Evaluates {@code net} on {@code iterator} with the configured ROC evaluation.
     *
     * @param net      the network to evaluate
     * @param iterator the data to evaluate on
     * @return the configured metric (AUC or AUPRC) of the evaluation
     * @throws IllegalStateException if {@code type} or {@code metric} has not been set
     */
    @Override
    public double score(MultiLayerNetwork net, DataSetIterator iterator) {
        // Validate before running a potentially expensive evaluation.
        checkConfigured();
        switch (this.type) {
            case ROC:
                return rocScore(net.evaluateROC(iterator));
            case BINARY: {
                // doEvaluation populates the supplied evaluation instance.
                ROCBinary binary = new ROCBinary();
                net.doEvaluation(iterator, binary);
                return binaryScore(binary);
            }
            case MULTICLASS:
                return multiClassScore(net.evaluateROCMultiClass(iterator));
            default:
                // Unreachable for the current enum constants; kept so every path returns or throws.
                throw new RuntimeException("Unknown type: " + this.type);
        }
    }

    /**
     * Adapts a {@link MultiDataSetIterator} to a {@link DataSetIterator} and scores the network
     * on it.
     *
     * @throws IllegalStateException if {@code type} or {@code metric} has not been set
     */
    @Override
    public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) {
        DataSetIterator adapted = new MultiDataSetWrapperIterator(iterator);
        return this.score(net, adapted);
    }

    /**
     * Adapts a {@link DataSetIterator} to a {@link MultiDataSetIterator} and scores the graph
     * on it.
     *
     * @throws IllegalStateException if {@code type} or {@code metric} has not been set
     */
    @Override
    public double score(ComputationGraph graph, DataSetIterator iterator) {
        MultiDataSetIterator adapted = new MultiDataSetIteratorAdapter(iterator);
        return this.score(graph, adapted);
    }

    /**
     * Evaluates {@code net} on {@code iterator} with the configured ROC evaluation.
     *
     * @param net      the computation graph to evaluate
     * @param iterator the data to evaluate on
     * @return the configured metric (AUC or AUPRC) of the evaluation
     * @throws IllegalStateException if {@code type} or {@code metric} has not been set
     */
    @Override
    public double score(ComputationGraph net, MultiDataSetIterator iterator) {
        // Validate before running a potentially expensive evaluation.
        checkConfigured();
        switch (this.type) {
            case ROC:
                return rocScore(net.evaluateROC(iterator));
            case BINARY: {
                // doEvaluation populates the supplied evaluation instance.
                ROCBinary binary = new ROCBinary();
                net.doEvaluation(iterator, binary);
                return binaryScore(binary);
            }
            case MULTICLASS:
                // NOTE(review): unlike the MultiLayerNetwork path, this passes an explicit second
                // argument (0), presumably the ROC threshold-step count; confirm the two paths
                // are meant to use the same setting.
                return multiClassScore(net.evaluateROCMultiClass(iterator, 0));
            default:
                // Unreachable for the current enum constants; kept so every path returns or throws.
                throw new RuntimeException("Unknown type: " + this.type);
        }
    }

    /**
     * Fails fast when the evaluation type or metric is missing, which can happen after the
     * no-arg constructor or a setter call with null. Without this check a null metric would
     * silently be reported as AUPRC.
     */
    private void checkConfigured() {
        if (this.type == null) {
            throw new IllegalStateException("ROCScoreFunction: type has not been set");
        }
        if (this.metric == null) {
            throw new IllegalStateException("ROCScoreFunction: metric has not been set");
        }
    }

    /** Returns the configured metric from a {@link ROC} evaluation. */
    private double rocScore(ROC roc) {
        return this.metric == Metric.AUC ? roc.calculateAUC() : roc.calculateAUCPR();
    }

    /** Returns the configured averaged metric from a {@link ROCBinary} evaluation. */
    private double binaryScore(ROCBinary roc) {
        return this.metric == Metric.AUC ? roc.calculateAverageAuc() : roc.calculateAverageAUCPR();
    }

    /** Returns the configured averaged metric from a {@link ROCMultiClass} evaluation. */
    private double multiClassScore(ROCMultiClass roc) {
        return this.metric == Metric.AUC ? roc.calculateAverageAUC() : roc.calculateAverageAUCPR();
    }

    /**
     * Returns {@code false}: higher AUC/AUPRC values are better, so this score is maximized.
     *
     * @return always {@code false}
     */
    public boolean minimize() {
        return false;
    }

    /** @return the ROC evaluation type; may be null if it has not been set */
    public ROCType getType() {
        return this.type;
    }

    /** @return the reported metric; may be null if it has not been set */
    public Metric getMetric() {
        return this.metric;
    }

    /** @param type the ROC evaluation type to use when scoring */
    public void setType(ROCType type) {
        this.type = type;
    }

    /** @param metric the metric to report when scoring */
    public void setMetric(Metric metric) {
        this.metric = metric;
    }

    /**
     * Two instances are equal when the superclass considers them equal and they share the same
     * type and metric.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ROCScoreFunction)) {
            return false;
        }
        ROCScoreFunction other = (ROCScoreFunction) o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        // Enum constants are singletons, so reference comparison is equivalent to equals()
        // and also treats two nulls as equal.
        return this.getType() == other.getType() && this.getMetric() == other.getMetric();
    }

    /** Symmetry guard used by {@link #equals(Object)} so subclasses can opt out of equality. */
    @Override
    protected boolean canEqual(Object other) {
        return other instanceof ROCScoreFunction;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        // 43 is used as the hash contribution of a null field.
        final int nullHash = 43;
        int result = super.hashCode();
        ROCType t = this.getType();
        result = result * prime + (t == null ? nullHash : t.hashCode());
        Metric m = this.getMetric();
        result = result * prime + (m == null ? nullHash : m.hashCode());
        return result;
    }

    /** The area-under-curve value reported by the score function. */
    public static enum Metric {
        /** Area under the ROC curve. */
        AUC,
        /** Area under the precision-recall curve. */
        AUPRC;

    }

    /** The ROC evaluation run by the score function. */
    public static enum ROCType {
        /** A single {@link org.nd4j.evaluation.classification.ROC} evaluation. */
        ROC,
        /** A {@link ROCBinary} evaluation; the averaged metric is reported. */
        BINARY,
        /** A {@link ROCMultiClass} evaluation; the averaged metric is reported. */
        MULTICLASS;

    }
}

