/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.graph;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.DoubleFlatMapFunction;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.rdd.RDD;
import org.datavec.spark.util.BroadcastHadoopConfigHolder;
import org.datavec.spark.util.SerializableHadoopConfig;
import org.deeplearning4j.api.loader.DataSetLoader;
import org.deeplearning4j.api.loader.MultiDataSetLoader;
import org.deeplearning4j.api.loader.impl.SerializedDataSetLoader;
import org.deeplearning4j.api.loader.impl.SerializedMultiDataSetLoader;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.impl.SparkListenable;
import org.deeplearning4j.spark.impl.common.reduce.IntDoubleReduceFunction;
import org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.graph.dataset.PairDataSetToMultiDataSetFn;
import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSFlatMapFunction;
import org.deeplearning4j.spark.impl.graph.evaluation.IEvaluateMDSPathsFlatMapFunction;
import org.deeplearning4j.spark.impl.graph.scoring.ArrayPairToPair;
import org.deeplearning4j.spark.impl.graph.scoring.GraphFeedForwardWithKeyFunction;
import org.deeplearning4j.spark.impl.graph.scoring.PairToArrayPair;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreExamplesFunction;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreExamplesWithKeyFunction;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreFlatMapFunctionCGDataSet;
import org.deeplearning4j.spark.impl.graph.scoring.ScoreFlatMapFunctionCGMultiDataSet;
import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluateAggregateFunction;
import org.deeplearning4j.spark.impl.multilayer.evaluation.IEvaluateFlatMapFunction;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.Heartbeat;
import org.nd4j.linalg.heartbeat.reports.Environment;
import org.nd4j.linalg.heartbeat.reports.Event;
import org.nd4j.linalg.heartbeat.reports.Task;
import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Spark front-end for a {@link ComputationGraph}.
 *
 * <p>Training is delegated to the configured {@link TrainingMaster}. Scoring, inference and
 * evaluation broadcast the configuration JSON and the current parameters of the wrapped network
 * to the executors and run there.
 *
 * <p>{@code trainingMaster} is not declared in this class; presumably it is a field inherited from
 * {@link SparkListenable} — TODO confirm.
 */
public class SparkComputationGraph
extends SparkListenable {
    private static final Logger log = LoggerFactory.getLogger(SparkComputationGraph.class);
    /** Default number of threshold steps used by the ROC evaluation helpers. */
    public static final int DEFAULT_ROC_THRESHOLD_STEPS = 32;
    /** Default minibatch size used when scoring or evaluating on the executors. */
    public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 64;
    /** Initial value returned by {@link #getDefaultEvaluationWorkers()}. */
    public static final int DEFAULT_EVAL_WORKERS = 4;
    private transient JavaSparkContext sc;
    // Private copy of the network configuration; its JSON form is broadcast to executors.
    private ComputationGraphConfiguration conf;
    private ComputationGraph network;
    private double lastScore;
    private int defaultEvaluationWorkers = DEFAULT_EVAL_WORKERS;
    // Not read anywhere in this class.
    private transient AtomicInteger iterationsCount = new AtomicInteger(0);

    /**
     * Wraps an existing network.
     *
     * @param sparkContext   Spark context; wrapped in a new {@link JavaSparkContext}
     * @param network        network to train and evaluate; {@code init()} is called on it
     * @param trainingMaster strategy used for distributed training
     */
    public SparkComputationGraph(SparkContext sparkContext, ComputationGraph network, TrainingMaster trainingMaster) {
        this(new JavaSparkContext(sparkContext), network, trainingMaster);
    }

    /**
     * Wraps an existing network. A clone of the network's configuration is kept for broadcasting,
     * the network is initialized, and the Spark Kryo configuration is checked.
     *
     * @param javaSparkContext Spark context used for all distributed operations
     * @param network          network to train and evaluate; {@code init()} is called on it
     * @param trainingMaster   strategy used for distributed training
     */
    public SparkComputationGraph(JavaSparkContext javaSparkContext, ComputationGraph network, TrainingMaster trainingMaster) {
        this.sc = javaSparkContext;
        this.trainingMaster = trainingMaster;
        this.conf = network.getConfiguration().clone();
        this.network = network;
        this.network.init();
        SparkUtils.checkKryoConfiguration(javaSparkContext, log);
    }

    /**
     * Builds and initializes a new network from a configuration.
     *
     * @param sparkContext   Spark context; wrapped in a new {@link JavaSparkContext}
     * @param conf           network configuration (cloned for broadcasting)
     * @param trainingMaster strategy used for distributed training
     */
    public SparkComputationGraph(SparkContext sparkContext, ComputationGraphConfiguration conf, TrainingMaster trainingMaster) {
        this(new JavaSparkContext(sparkContext), conf, trainingMaster);
    }

    /**
     * Builds and initializes a new network from a configuration. A clone of {@code conf} is kept
     * for broadcasting, and the Spark Kryo configuration is checked.
     *
     * @param sparkContext   Spark context used for all distributed operations
     * @param conf           network configuration
     * @param trainingMaster strategy used for distributed training
     */
    public SparkComputationGraph(JavaSparkContext sparkContext, ComputationGraphConfiguration conf, TrainingMaster trainingMaster) {
        this.sc = sparkContext;
        this.trainingMaster = trainingMaster;
        this.conf = conf.clone();
        this.network = new ComputationGraph(conf);
        this.network.init();
        SparkUtils.checkKryoConfiguration(sparkContext, log);
    }

    /** @return the Spark context used by this instance */
    public JavaSparkContext getSparkContext() {
        return this.sc;
    }

    /** Enables or disables training-statistics collection on the training master. */
    public void setCollectTrainingStats(boolean collectTrainingStats) {
        this.trainingMaster.setCollectTrainingStats(collectTrainingStats);
    }

    /** @return training statistics reported by the training master */
    public SparkTrainingStats getSparkTrainingStats() {
        return this.trainingMaster.getTrainingStats();
    }

    /** @return the wrapped network */
    public ComputationGraph getNetwork() {
        return this.network;
    }

    /** @return the training master used for distributed training */
    public TrainingMaster getTrainingMaster() {
        return this.trainingMaster;
    }

    /**
     * Replaces the wrapped network.
     *
     * <p>The cached configuration is refreshed from the new network, so the configuration JSON
     * broadcast for scoring and evaluation matches the parameters broadcast alongside it.
     *
     * @param network the new network; when {@code null} the cached configuration is left unchanged
     */
    public void setNetwork(ComputationGraph network) {
        this.network = network;
        if (network != null) {
            this.conf = network.getConfiguration().clone();
        }
    }

    /** @return number of evaluation workers used by evaluation methods that do not take one */
    public int getDefaultEvaluationWorkers() {
        return this.defaultEvaluationWorkers;
    }

    /**
     * Sets the number of evaluation workers used by evaluation methods that do not take one.
     *
     * @param workers number of workers; must be positive
     */
    public void setDefaultEvaluationWorkers(int workers) {
        Preconditions.checkArgument(workers > 0, "Number of workers must be > 0: got %s", workers);
        this.defaultEvaluationWorkers = workers;
    }

    /** Flushes any queued ops when the ND4J executioner is a {@link GridExecutioner}. */
    private static void flushGridExecutioner() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }
    }

    /**
     * Lists the files under {@code path}, wrapping any {@link IOException}.
     *
     * @param path         directory to list
     * @param errorMessage message of the {@link RuntimeException} thrown on failure
     * @return RDD of file paths
     */
    private JavaRDD<String> listPathsOrThrow(String path, String errorMessage) {
        try {
            return SparkUtils.listPaths(this.sc, path);
        }
        catch (IOException e) {
            throw new RuntimeException(errorMessage, e);
        }
    }

    /** Fits the network on an RDD of DataSets for one epoch. */
    public ComputationGraph fit(RDD<DataSet> rdd) {
        return this.fit(rdd.toJavaRDD());
    }

    /**
     * Fits the network on an RDD of DataSets for one epoch, then increments the epoch count.
     *
     * @return the wrapped network
     */
    public ComputationGraph fit(JavaRDD<DataSet> rdd) {
        flushGridExecutioner();
        this.trainingMaster.executeTraining(this, rdd);
        this.network.incrementEpochCount();
        return this.network;
    }

    /**
     * Fits the network on the serialized DataSets stored under {@code path}.
     *
     * @return the wrapped network
     * @throws RuntimeException if the directory cannot be listed
     */
    public ComputationGraph fit(String path) {
        JavaRDD<String> paths = this.listPathsOrThrow(path, "Error listing paths in directory");
        return this.fitPaths(paths);
    }

    /**
     * @deprecated {@code minPartitions} is ignored; use {@link #fit(String)}.
     */
    @Deprecated
    public ComputationGraph fit(String path, int minPartitions) {
        return this.fit(path);
    }

    /** Fits the network on serialized DataSets at the given paths, using {@link SerializedDataSetLoader}. */
    public ComputationGraph fitPaths(JavaRDD<String> paths) {
        return this.fitPaths(paths, (DataSetLoader)new SerializedDataSetLoader());
    }

    /**
     * Fits the network on DataSets loaded from the given paths, then increments the epoch count.
     *
     * @param paths  file paths to load
     * @param loader loader used to read each path
     * @return the wrapped network
     */
    public ComputationGraph fitPaths(JavaRDD<String> paths, DataSetLoader loader) {
        flushGridExecutioner();
        // NOTE(review): first argument is null here; presumably it is a multi-layer-network slot
        // that is unused for ComputationGraph training — TODO confirm against TrainingMaster.
        this.trainingMaster.executeTrainingPaths(null, this, paths, loader, null);
        this.network.incrementEpochCount();
        return this.network;
    }

    /** Fits the network on an RDD of MultiDataSets for one epoch. */
    public ComputationGraph fitMultiDataSet(RDD<MultiDataSet> rdd) {
        return this.fitMultiDataSet(rdd.toJavaRDD());
    }

    /**
     * Fits the network on an RDD of MultiDataSets for one epoch, then increments the epoch count.
     *
     * @return the wrapped network
     */
    public ComputationGraph fitMultiDataSet(JavaRDD<MultiDataSet> rdd) {
        flushGridExecutioner();
        this.trainingMaster.executeTrainingMDS(this, rdd);
        this.network.incrementEpochCount();
        return this.network;
    }

    /**
     * Fits the network on the serialized MultiDataSets stored under {@code path}.
     *
     * @return the wrapped network
     * @throws RuntimeException if the directory cannot be listed
     */
    public ComputationGraph fitMultiDataSet(String path) {
        JavaRDD<String> paths = this.listPathsOrThrow(path, "Error listing paths in directory");
        return this.fitPathsMultiDataSet(paths);
    }

    /** Fits the network on serialized MultiDataSets at the given paths, using {@link SerializedMultiDataSetLoader}. */
    public ComputationGraph fitPathsMultiDataSet(JavaRDD<String> paths) {
        return this.fitPaths(paths, (MultiDataSetLoader)new SerializedMultiDataSetLoader());
    }

    /**
     * Fits the network on MultiDataSets loaded from the given paths, then increments the epoch count.
     *
     * @param paths  file paths to load
     * @param loader loader used to read each path
     * @return the wrapped network
     */
    public ComputationGraph fitPaths(JavaRDD<String> paths, MultiDataSetLoader loader) {
        flushGridExecutioner();
        this.trainingMaster.executeTrainingPaths(null, this, paths, null, loader);
        this.network.incrementEpochCount();
        return this.network;
    }

    /**
     * @deprecated {@code minPartitions} is ignored; use {@link #fitMultiDataSet(String)}.
     */
    @Deprecated
    public ComputationGraph fitMultiDataSet(String path, int minPartitions) {
        return this.fitMultiDataSet(path);
    }

    /** @return the last score stored via {@link #setScore(double)} */
    public double getScore() {
        return this.lastScore;
    }

    /** Stores the most recent score. */
    public void setScore(double lastScore) {
        this.lastScore = lastScore;
    }

    /** Scores the data with minibatch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public double calculateScore(JavaRDD<DataSet> data, boolean average) {
        return this.calculateScore(data, average, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Scores the data on the executors using the current (non-backward-only) parameters.
     *
     * @param average       if true, divide the summed score by the example count
     * @param minibatchSize minibatch size used on the executors
     * @return summed or averaged score
     */
    public double calculateScore(JavaRDD<DataSet> data, boolean average, int minibatchSize) {
        JavaRDD rdd = data.mapPartitions((FlatMapFunction)new ScoreFlatMapFunctionCGDataSet(this.conf.toJson(), this.sc.broadcast(this.network.params(false)), minibatchSize));
        return reduceScores(rdd, average);
    }

    /** Scores MultiDataSets with minibatch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public double calculateScoreMultiDataSet(JavaRDD<MultiDataSet> data, boolean average) {
        return this.calculateScoreMultiDataSet(data, average, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Scores MultiDataSets on the executors using the current (non-backward-only) parameters.
     *
     * @param average       if true, divide the summed score by the example count
     * @param minibatchSize minibatch size used on the executors
     * @return summed or averaged score
     */
    public double calculateScoreMultiDataSet(JavaRDD<MultiDataSet> data, boolean average, int minibatchSize) {
        JavaRDD rdd = data.mapPartitions((FlatMapFunction)new ScoreFlatMapFunctionCGMultiDataSet(this.conf.toJson(), this.sc.broadcast(this.network.params(false)), minibatchSize));
        return reduceScores(rdd, average);
    }

    /**
     * Reduces per-partition (count, scoreSum) tuples into a total or an average.
     * Spark's {@code reduce} fails on an empty RDD.
     */
    private static double reduceScores(JavaRDD rdd, boolean average) {
        Tuple2 countAndSumScores = (Tuple2)rdd.reduce((Function2)new IntDoubleReduceFunction());
        double sum = (Double)countAndSumScores._2();
        if (average) {
            return sum / (double)((Integer)countAndSumScores._1()).intValue();
        }
        return sum;
    }

    /** Per-example scores, with minibatch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamplesMultiDataSet((JavaRDD<MultiDataSet>)data.map((Function)new DataSetToMultiDataSetFn()), includeRegularizationTerms);
    }

    /** Per-example scores; each DataSet is converted to a MultiDataSet first. */
    public JavaDoubleRDD scoreExamples(JavaRDD<DataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return this.scoreExamplesMultiDataSet((JavaRDD<MultiDataSet>)data.map((Function)new DataSetToMultiDataSetFn()), includeRegularizationTerms, batchSize);
    }

    /** Keyed per-example scores, with minibatch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamplesMultiDataSet(data.mapToPair(new PairDataSetToMultiDataSetFn()), includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Keyed per-example scores; each DataSet is converted to a MultiDataSet first. */
    public <K> JavaPairRDD<K, Double> scoreExamples(JavaPairRDD<K, DataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return this.scoreExamplesMultiDataSet(data.mapToPair(new PairDataSetToMultiDataSetFn()), includeRegularizationTerms, batchSize);
    }

    /** Per-example MultiDataSet scores, with minibatch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public JavaDoubleRDD scoreExamplesMultiDataSet(JavaRDD<MultiDataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamplesMultiDataSet(data, includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Per-example MultiDataSet scores computed on the executors.
     *
     * @param includeRegularizationTerms passed through to the scoring function
     * @param batchSize                  minibatch size used on the executors
     */
    public JavaDoubleRDD scoreExamplesMultiDataSet(JavaRDD<MultiDataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return data.mapPartitionsToDouble((DoubleFlatMapFunction)new ScoreExamplesFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    /** Keyed per-example MultiDataSet scores, with minibatch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <K> JavaPairRDD<K, Double> scoreExamplesMultiDataSet(JavaPairRDD<K, MultiDataSet> data, boolean includeRegularizationTerms) {
        return this.scoreExamplesMultiDataSet(data, includeRegularizationTerms, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Keyed per-example MultiDataSet scores computed on the executors.
     *
     * @param includeRegularizationTerms passed through to the scoring function
     * @param batchSize                  minibatch size used on the executors
     */
    public <K> JavaPairRDD<K, Double> scoreExamplesMultiDataSet(JavaPairRDD<K, MultiDataSet> data, boolean includeRegularizationTerms, int batchSize) {
        return data.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), includeRegularizationTerms, batchSize));
    }

    /**
     * Keyed feed-forward inference for a graph with exactly one input and one output.
     *
     * @throws IllegalStateException if the graph does not have exactly one input and one output
     */
    public <K> JavaPairRDD<K, INDArray> feedForwardWithKeySingle(JavaPairRDD<K, INDArray> featuresData, int batchSize) {
        if (this.network.getNumInputArrays() != 1 || this.network.getNumOutputArrays() != 1) {
            throw new IllegalStateException("Cannot use this method with computation graphs with more than 1 input or output ( has: " + this.network.getNumInputArrays() + " inputs, " + this.network.getNumOutputArrays() + " outputs)");
        }
        // Wrap single arrays into INDArray[] for the multi-input path, then unwrap the results.
        JavaPairRDD rdd = featuresData.mapToPair(new PairToArrayPair());
        return this.feedForwardWithKey(rdd, batchSize).mapToPair(new ArrayPairToPair());
    }

    /** Keyed feed-forward inference on the executors, using the broadcast parameters and configuration. */
    public <K> JavaPairRDD<K, INDArray[]> feedForwardWithKey(JavaPairRDD<K, INDArray[]> featuresData, int batchSize) {
        return featuresData.mapPartitionsToPair(new GraphFeedForwardWithKeyFunction(this.sc.broadcast(this.network.params()), this.sc.broadcast(this.conf.toJson()), batchSize));
    }

    // Reports a heartbeat event for this network. Not called from anywhere in this class.
    private void update(int mr, long mg) {
        Environment env = EnvironmentUtils.buildEnvironment();
        env.setNumCores(mr);
        env.setAvailableMemory(mg);
        Task task = ModelSerializer.taskByModel((Model)this.network);
        Heartbeat.getInstance().reportEvent(Event.SPARK, env, task);
    }

    /**
     * Classification evaluation of the DataSets stored under {@code path}. Uses
     * {@link #getDefaultEvaluationWorkers()} workers and batch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}.
     *
     * @throws RuntimeException if the directory cannot be listed
     */
    public Evaluation evaluate(String path, DataSetLoader loader) {
        JavaRDD<String> data = this.listPathsOrThrow(path, "Error listing files for evaluation of files at path: " + path);
        return (Evaluation)this.doEvaluation(data, this.getDefaultEvaluationWorkers(), DEFAULT_EVAL_SCORE_BATCH_SIZE, loader, (MultiDataSetLoader)null, new IEvaluation[]{new Evaluation()})[0];
    }

    /**
     * Classification evaluation of the MultiDataSets stored under {@code path}. Uses
     * {@link #getDefaultEvaluationWorkers()} workers and batch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}.
     *
     * @throws RuntimeException if the directory cannot be listed
     */
    public Evaluation evaluate(String path, MultiDataSetLoader loader) {
        JavaRDD<String> data = this.listPathsOrThrow(path, "Error listing files for evaluation of files at path: " + path);
        return (Evaluation)this.doEvaluation(data, this.getDefaultEvaluationWorkers(), DEFAULT_EVAL_SCORE_BATCH_SIZE, (DataSetLoader)null, loader, new IEvaluation[]{new Evaluation()})[0];
    }

    /** Classification evaluation of an RDD of DataSets. */
    public <T extends Evaluation> T evaluate(RDD<DataSet> data) {
        return this.evaluate(data.toJavaRDD());
    }

    /** Classification evaluation of an RDD of DataSets. */
    public <T extends Evaluation> T evaluate(JavaRDD<DataSet> data) {
        return this.evaluate(data, null);
    }

    /** Classification evaluation; {@code labelsList}, if non-null, is set on the result. */
    public <T extends Evaluation> T evaluate(RDD<DataSet> data, List<String> labelsList) {
        return this.evaluate(data.toJavaRDD(), labelsList);
    }

    /** Regression evaluation with batch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegression(JavaRDD<DataSet> data) {
        return this.evaluateRegression(data, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /**
     * Regression evaluation. The column count is the nOut of output layer 0, which is cast to
     * {@link FeedForwardLayer}.
     */
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegression(JavaRDD<DataSet> data, int minibatchSize) {
        long nOut = ((FeedForwardLayer)this.network.getOutputLayer(0).conf().getLayer()).getNOut();
        return (T)((org.nd4j.evaluation.regression.RegressionEvaluation)this.doEvaluation(data, new RegressionEvaluation(nOut), minibatchSize));
    }

    /** Classification evaluation with batch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends Evaluation> T evaluate(JavaRDD<DataSet> data, List<String> labelsList) {
        return this.evaluate(data, labelsList, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** ROC evaluation with the default threshold steps and batch size. */
    public <T extends org.nd4j.evaluation.classification.ROC> T evaluateROC(JavaRDD<DataSet> data) {
        return this.evaluateROC(data, DEFAULT_ROC_THRESHOLD_STEPS, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** ROC evaluation with the given threshold steps and minibatch size. */
    public <T extends org.nd4j.evaluation.classification.ROC> T evaluateROC(JavaRDD<DataSet> data, int thresholdSteps, int evaluationMinibatchSize) {
        return (T)((org.nd4j.evaluation.classification.ROC)this.doEvaluation(data, new ROC(thresholdSteps), evaluationMinibatchSize));
    }

    /** Multi-class ROC evaluation with the default threshold steps and batch size. */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(JavaRDD<DataSet> data) {
        return this.evaluateROCMultiClass(data, DEFAULT_ROC_THRESHOLD_STEPS, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Multi-class ROC evaluation with the given threshold steps and minibatch size. */
    public <T extends ROCMultiClass> T evaluateROCMultiClass(JavaRDD<DataSet> data, int thresholdSteps, int evaluationMinibatchSize) {
        return (T)((ROCMultiClass)this.doEvaluation(data, new org.deeplearning4j.eval.ROCMultiClass(thresholdSteps), evaluationMinibatchSize));
    }

    /**
     * Classification evaluation.
     *
     * @param labelsList    if non-null, set as the label names on the result
     * @param evalBatchSize minibatch size used on the executors
     */
    public <T extends Evaluation> T evaluate(JavaRDD<DataSet> data, List<String> labelsList, int evalBatchSize) {
        Evaluation e = new org.deeplearning4j.eval.Evaluation();
        e = this.doEvaluation(data, e, evalBatchSize);
        if (labelsList != null) {
            e.setLabelsList(labelsList);
        }
        return (T)e;
    }

    /** Classification evaluation of MultiDataSets with batch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends Evaluation> T evaluateMDS(JavaRDD<MultiDataSet> data) {
        return this.evaluateMDS(data, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Classification evaluation of MultiDataSets. */
    public <T extends Evaluation> T evaluateMDS(JavaRDD<MultiDataSet> data, int minibatchSize) {
        return (T)((org.deeplearning4j.eval.Evaluation[])this.doEvaluationMDS(data, minibatchSize, (IEvaluation[])new org.deeplearning4j.eval.Evaluation[]{new org.deeplearning4j.eval.Evaluation()}))[0];
    }

    /** Regression evaluation of MultiDataSets with batch size {@link #DEFAULT_EVAL_SCORE_BATCH_SIZE}. */
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegressionMDS(JavaRDD<MultiDataSet> data) {
        return this.evaluateRegressionMDS(data, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** Regression evaluation of MultiDataSets. */
    public <T extends org.nd4j.evaluation.regression.RegressionEvaluation> T evaluateRegressionMDS(JavaRDD<MultiDataSet> data, int minibatchSize) {
        return (T)((RegressionEvaluation[])this.doEvaluationMDS(data, minibatchSize, (IEvaluation[])new RegressionEvaluation[]{new RegressionEvaluation()}))[0];
    }

    /** ROC evaluation of MultiDataSets with the default threshold steps and batch size. */
    public org.nd4j.evaluation.classification.ROC evaluateROCMDS(JavaRDD<MultiDataSet> data) {
        return this.evaluateROCMDS(data, DEFAULT_ROC_THRESHOLD_STEPS, DEFAULT_EVAL_SCORE_BATCH_SIZE);
    }

    /** ROC evaluation of MultiDataSets with the given threshold steps and minibatch size. */
    public <T extends org.nd4j.evaluation.classification.ROC> T evaluateROCMDS(JavaRDD<MultiDataSet> data, int rocThresholdNumSteps, int minibatchSize) {
        return (T)((ROC[])this.doEvaluationMDS(data, minibatchSize, (IEvaluation[])new ROC[]{new ROC(rocThresholdNumSteps)}))[0];
    }

    /** Runs a single evaluation over DataSets and returns the merged result. */
    public <T extends IEvaluation> T doEvaluation(JavaRDD<DataSet> data, T emptyEvaluation, int evalBatchSize) {
        IEvaluation[] arr = new IEvaluation[]{emptyEvaluation};
        return (T)this.doEvaluation(data, evalBatchSize, arr)[0];
    }

    /** Runs evaluations over DataSets with {@link #getDefaultEvaluationWorkers()} workers. */
    public <T extends IEvaluation> T[] doEvaluation(JavaRDD<DataSet> data, int evalBatchSize, T ... emptyEvaluations) {
        return this.doEvaluation(data, this.getDefaultEvaluationWorkers(), evalBatchSize, emptyEvaluations);
    }

    /**
     * Runs evaluations over DataSets on the executors and tree-aggregates the per-partition results.
     *
     * @param evalNumWorkers number of evaluation workers; must be positive
     * @param evalBatchSize  minibatch size used on the executors
     * @throws IllegalArgumentException if {@code evalNumWorkers <= 0}
     */
    public <T extends IEvaluation> T[] doEvaluation(JavaRDD<DataSet> data, int evalNumWorkers, int evalBatchSize, T ... emptyEvaluations) {
        Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaluation workers: require at least 1 - got %s", evalNumWorkers);
        IEvaluateFlatMapFunction evalFn = new IEvaluateFlatMapFunction(true, this.sc.broadcast(this.conf.toJson()), SparkUtils.asByteArrayBroadcast(this.sc, this.network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations);
        JavaRDD evaluations = data.mapPartitions((FlatMapFunction)evalFn);
        return (T[])evaluations.treeAggregate(null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction());
    }

    /** Runs evaluations over MultiDataSets with {@link #getDefaultEvaluationWorkers()} workers. */
    public <T extends IEvaluation> T[] doEvaluationMDS(JavaRDD<MultiDataSet> data, int evalBatchSize, T ... emptyEvaluations) {
        return this.doEvaluationMDS(data, this.getDefaultEvaluationWorkers(), evalBatchSize, emptyEvaluations);
    }

    /**
     * Runs evaluations over MultiDataSets on the executors and tree-aggregates the per-partition results.
     *
     * @param evalNumWorkers number of evaluation workers; must be positive
     * @param evalBatchSize  minibatch size used on the executors
     * @throws IllegalArgumentException if {@code evalNumWorkers <= 0}
     */
    public <T extends IEvaluation> T[] doEvaluationMDS(JavaRDD<MultiDataSet> data, int evalNumWorkers, int evalBatchSize, T ... emptyEvaluations) {
        Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaluation workers: require at least 1 - got %s", evalNumWorkers);
        IEvaluateMDSFlatMapFunction evalFn = new IEvaluateMDSFlatMapFunction(this.sc.broadcast(this.conf.toJson()), SparkUtils.asByteArrayBroadcast(this.sc, this.network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations);
        JavaRDD evaluations = data.mapPartitions((FlatMapFunction)evalFn);
        return (T[])evaluations.treeAggregate(null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction());
    }

    /** Evaluates DataSets loaded from paths, using the default worker count and batch size. */
    public IEvaluation[] doEvaluation(JavaRDD<String> data, DataSetLoader loader, IEvaluation ... emptyEvaluations) {
        return this.doEvaluation(data, this.getDefaultEvaluationWorkers(), DEFAULT_EVAL_SCORE_BATCH_SIZE, loader, emptyEvaluations);
    }

    /** Evaluates DataSets loaded from paths. */
    public IEvaluation[] doEvaluation(JavaRDD<String> data, int evalNumWorkers, int evalBatchSize, DataSetLoader loader, IEvaluation ... emptyEvaluations) {
        return this.doEvaluation(data, evalNumWorkers, evalBatchSize, loader, (MultiDataSetLoader)null, emptyEvaluations);
    }

    /** Evaluates MultiDataSets loaded from paths, using the default worker count and batch size. */
    public IEvaluation[] doEvaluation(JavaRDD<String> data, MultiDataSetLoader loader, IEvaluation ... emptyEvaluations) {
        return this.doEvaluation(data, this.getDefaultEvaluationWorkers(), DEFAULT_EVAL_SCORE_BATCH_SIZE, (DataSetLoader)null, loader, emptyEvaluations);
    }

    /** Evaluates MultiDataSets loaded from paths. */
    public IEvaluation[] doEvaluation(JavaRDD<String> data, int evalNumWorkers, int evalBatchSize, MultiDataSetLoader loader, IEvaluation ... emptyEvaluations) {
        return this.doEvaluation(data, evalNumWorkers, evalBatchSize, (DataSetLoader)null, loader, emptyEvaluations);
    }

    /**
     * Evaluates data loaded from paths with either {@code loader} or {@code mdsLoader}, on the
     * executors, and tree-aggregates the per-partition results.
     *
     * @throws IllegalArgumentException if {@code evalNumWorkers <= 0}
     */
    protected IEvaluation[] doEvaluation(JavaRDD<String> data, int evalNumWorkers, int evalBatchSize, DataSetLoader loader, MultiDataSetLoader mdsLoader, IEvaluation ... emptyEvaluations) {
        // Validate before creating any broadcast variables.
        Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaluation workers: require at least 1 - got %s", evalNumWorkers);
        IEvaluateMDSPathsFlatMapFunction evalFn = new IEvaluateMDSPathsFlatMapFunction(this.sc.broadcast(this.conf.toJson()), SparkUtils.asByteArrayBroadcast(this.sc, this.network.params()), evalNumWorkers, evalBatchSize, loader, mdsLoader, (Broadcast<SerializableHadoopConfig>)BroadcastHadoopConfigHolder.get(this.sc), emptyEvaluations);
        JavaRDD evaluations = data.mapPartitions((FlatMapFunction)evalFn);
        return (IEvaluation[])evaluations.treeAggregate(null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction());
    }
}

