/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.api.worker;

import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.stats.StatsCalculationHelper;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

public class ExecuteWorkerFlatMap<R extends TrainingResult>
implements FlatMapFunction<Iterator<org.nd4j.linalg.dataset.DataSet>, R> {
    private final TrainingWorker<R> worker;

    public ExecuteWorkerFlatMap(TrainingWorker<R> worker) {
        this.worker = worker;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public Iterator<R> call(Iterator<org.nd4j.linalg.dataset.DataSet> dataSetIterator) throws Exception {
        StatsCalculationHelper s;
        WorkerConfiguration dataConfig = this.worker.getDataConfiguration();
        boolean isGraph = dataConfig.isGraphNetwork();
        boolean stats = dataConfig.isCollectTrainingStats();
        StatsCalculationHelper statsCalculationHelper = s = stats ? new StatsCalculationHelper() : null;
        if (stats) {
            s.logMethodStartTime();
        }
        if (!dataSetIterator.hasNext()) {
            if (stats) {
                s.logReturnTime();
                Pair<R, SparkTrainingStats> pair = this.worker.getFinalResultNoDataWithStats();
                ((TrainingResult)pair.getFirst()).setStats(s.build((SparkTrainingStats)pair.getSecond()));
                return Collections.singletonList(pair.getFirst()).iterator();
            }
            return Collections.singletonList(this.worker.getFinalResultNoData()).iterator();
        }
        int batchSize = dataConfig.getBatchSizePerWorker();
        int prefetchCount = dataConfig.getPrefetchNumBatches();
        IteratorDataSetIterator batchedIterator = new IteratorDataSetIterator(dataSetIterator, batchSize);
        if (prefetchCount > 0) {
            batchedIterator = new AsyncDataSetIterator((DataSetIterator)batchedIterator, prefetchCount);
        }
        try {
            Iterator<R> iterator;
            int maxMinibatches;
            MultiLayerNetwork net = null;
            ComputationGraph graph = null;
            if (stats) {
                s.logInitialModelBefore();
            }
            if (isGraph) {
                graph = this.worker.getInitialModelGraph();
            } else {
                net = this.worker.getInitialModel();
            }
            if (stats) {
                s.logInitialModelAfter();
            }
            int miniBatchCount = 0;
            int n = maxMinibatches = dataConfig.getMaxBatchesPerWorker() > 0 ? dataConfig.getMaxBatchesPerWorker() : Integer.MAX_VALUE;
            while (batchedIterator.hasNext() && miniBatchCount++ < maxMinibatches) {
                Object result;
                if (stats) {
                    s.logNextDataSetBefore();
                }
                org.nd4j.linalg.dataset.DataSet next = (org.nd4j.linalg.dataset.DataSet)batchedIterator.next();
                if (stats) {
                    s.logNextDataSetAfter(next.numExamples());
                }
                if (stats) {
                    s.logProcessMinibatchBefore();
                    result = isGraph ? this.worker.processMinibatchWithStats((DataSet)next, graph, !batchedIterator.hasNext()) : this.worker.processMinibatchWithStats((DataSet)next, net, !batchedIterator.hasNext());
                    s.logProcessMinibatchAfter();
                    if (result == null) continue;
                    s.logReturnTime();
                    SparkTrainingStats workerStats = (SparkTrainingStats)result.getSecond();
                    CommonSparkTrainingStats returnStats = s.build(workerStats);
                    ((TrainingResult)result.getFirst()).setStats(returnStats);
                    Iterator<Object> iterator2 = Collections.singletonList(result.getFirst()).iterator();
                    return iterator2;
                }
                result = isGraph ? this.worker.processMinibatch((DataSet)next, graph, !batchedIterator.hasNext()) : this.worker.processMinibatch((DataSet)next, net, !batchedIterator.hasNext());
                if (result == null) continue;
                Iterator<R> iterator3 = Collections.singletonList(result).iterator();
                return iterator3;
            }
            if (stats) {
                s.logReturnTime();
                Pair<R, SparkTrainingStats> pair = isGraph ? this.worker.getFinalResultWithStats(graph) : this.worker.getFinalResultWithStats(net);
                ((TrainingResult)pair.getFirst()).setStats(s.build((SparkTrainingStats)pair.getSecond()));
                Iterator<Object> iterator4 = Collections.singletonList(pair.getFirst()).iterator();
                return iterator4;
            }
            if (isGraph) {
                iterator = Collections.singletonList(this.worker.getFinalResult(graph)).iterator();
                return iterator;
            }
            iterator = Collections.singletonList(this.worker.getFinalResult(net)).iterator();
            return iterator;
        }
        finally {
            Nd4j.getExecutioner().commit();
            if (batchedIterator instanceof AsyncDataSetIterator) {
                ((AsyncDataSetIterator)batchedIterator).shutdown();
            }
        }
    }
}

