/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.evaluation;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Queue;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Singleton that runs Spark-side evaluation using at most {@code evalWorkers} model instances per JVM.
 * <p>
 * A calling thread that obtains a free worker slot builds its own model from the broadcast configuration and
 * parameters. It evaluates its own data, then keeps evaluating data submitted by other threads until the queue
 * is empty.
 * <p>
 * A calling thread that cannot obtain a slot enqueues its data. It receives a future that completes once some
 * worker has processed that data.
 * <p>
 * Queued data is accumulated into the <em>worker's</em> {@link IEvaluation} instances, not the submitter's.
 * Futures for queued submissions therefore report a {@code null} result; only a worker's future carries the
 * evaluations.
 */
public class EvaluationRunner {
    private static final Logger log = LoggerFactory.getLogger(EvaluationRunner.class);
    private static final EvaluationRunner INSTANCE = new EvaluationRunner();

    /** Number of threads currently holding a worker slot. Only modified while holding {@link #workerLock}. */
    private final AtomicInteger workerCount = new AtomicInteger(0);
    /**
     * Makes two decisions atomic with respect to each other:
     * <ul>
     *   <li>"take a worker slot, or enqueue" (made by submitters);</li>
     *   <li>"queue is empty, so release the slot" (made by workers).</li>
     * </ul>
     * Without it, a submitter could see a stale full worker count and enqueue its data just after the last
     * worker saw an empty queue and left. That data, and the caller blocked on its future, would be stranded.
     */
    private final Object workerLock = new Object();
    /** Data submitted by threads that could not obtain a worker slot. Accessed under {@link #workerLock}. */
    private final Queue<Eval> queue = new ConcurrentLinkedQueue<>();
    /**
     * Deserialized parameters, keyed by the broadcast byte array. Arrays use identity equality, so the key is the
     * array instance. Guarded by {@code this}.
     */
    private final Map<byte[], DeviceLocalNDArray> paramsMap = new WeakHashMap<>();

    public static EvaluationRunner getInstance() {
        return INSTANCE;
    }

    private EvaluationRunner() {
    }

    /**
     * Evaluates the given data, either on the calling thread (if a worker slot is free) or by handing it to a
     * thread that currently holds a worker slot.
     *
     * @param evals         evaluation instances to accumulate into (when the caller becomes a worker)
     * @param evalWorkers   maximum number of concurrent worker threads / model instances; must be > 0
     * @param evalBatchSize minibatch size used when iterating over the data
     * @param ds            DataSet iterator; used when non-null
     * @param mds           MultiDataSet iterator; used when {@code ds} is null
     * @param isCG          true if {@code json} is a ComputationGraph configuration, false for MultiLayerNetwork
     * @param json          broadcast network configuration JSON
     * @param params        broadcast serialized network parameters
     * @return if the caller became a worker, a future whose result is {@code evals}. Those evaluations include
     *         any queued submissions the worker processed, and the future is returned only after the queue was
     *         drained. Otherwise, a future that completes (with a {@code null} result) once a worker has
     *         processed the submitted data. Evaluation failures are reported through the future.
     */
    public Future<IEvaluation[]> execute(IEvaluation[] evals, int evalWorkers, int evalBatchSize, Iterator<DataSet> ds,
                                         Iterator<MultiDataSet> mds, boolean isCG, Broadcast<String> json,
                                         Broadcast<byte[]> params) {
        Preconditions.checkArgument(evalWorkers > 0, "Invalid number of evaluation workers: must be > 0. Got: %s", evalWorkers);
        Preconditions.checkState(ds != null || mds != null, "No data provided - both DataSet and MultiDataSet iterators were null");

        DeviceLocalNDArray deviceLocalParams = getDeviceLocalParams(params);

        // Either take a worker slot or enqueue, atomically with respect to workers deciding to leave.
        EvaluationFuture queued = null;
        synchronized (workerLock) {
            if (workerCount.get() < evalWorkers) {
                workerCount.incrementAndGet();
            } else {
                queued = new EvaluationFuture();
                queue.add(new Eval(ds, mds, evals, queued));
            }
        }

        if (queued != null) {
            log.debug("Submitting evaluation from thread {} for processing in evaluation thread", Thread.currentThread().getId());
            return queued;
        }
        return runAsWorker(evals, evalBatchSize, ds, mds, isCG, json, deviceLocalParams);
    }

    /**
     * Returns the cached device-local parameters for this broadcast. On the first call for a given broadcast
     * array, the parameters are deserialized and cached.
     */
    private synchronized DeviceLocalNDArray getDeviceLocalParams(Broadcast<byte[]> params) {
        byte[] pBytes = params.getValue();
        DeviceLocalNDArray dlp = paramsMap.get(pBytes);
        if (dlp == null) {
            INDArray p;
            try {
                p = Nd4j.read(new ByteArrayInputStream(pBytes));
            } catch (Exception e) {
                // Catch Exception rather than a narrower type. This compiles whether or not Nd4j.read declares
                // IOException, and always preserves the cause.
                throw new RuntimeException("Failed to deserialize network parameters", e);
            }
            dlp = new DeviceLocalNDArray(p);
            paramsMap.put(pBytes, dlp);
        }
        return dlp;
    }

    /**
     * Body of a thread that holds a worker slot. It builds a model, evaluates the caller's own data, then drains
     * the queue. The slot is released on every exit path.
     */
    private EvaluationFuture runAsWorker(IEvaluation[] evals, int evalBatchSize, Iterator<DataSet> ds,
                                         Iterator<MultiDataSet> mds, boolean isCG, Broadcast<String> json,
                                         DeviceLocalNDArray deviceLocalParams) {
        log.debug("Starting evaluation in thread {}", Thread.currentThread().getId());
        EvaluationFuture f = new EvaluationFuture();
        f.setResult(evals);
        boolean holdsSlot = true;
        try {
            Model m = createModel(isCG, json, deviceLocalParams);

            // This thread's own data
            evalAndSignal(m, evals, ds, mds, evalBatchSize, f);

            // Other threads' data. This is accumulated into this worker's 'evals', not the submitters'.
            Eval e;
            while ((e = pollOrReleaseSlot()) != null) {
                evalAndSignal(m, evals, e.getDs(), e.getMds(), evalBatchSize, e.getFuture());
            }
            // pollOrReleaseSlot() released the slot when it found the queue empty
            holdsSlot = false;
        } finally {
            if (holdsSlot) {
                // Abnormal exit (e.g. model creation failed): give the slot back
                synchronized (workerLock) {
                    workerCount.decrementAndGet();
                }
            }
            log.debug("Finished evaluation in thread {}", Thread.currentThread().getId());
        }

        Nd4j.getExecutioner().commit();
        return f;
    }

    /**
     * Takes the next queued submission. If the queue is empty, it instead releases the caller's worker slot and
     * returns null. Both steps happen under {@link #workerLock}, so no submitter can enqueue in between.
     */
    private Eval pollOrReleaseSlot() {
        synchronized (workerLock) {
            Eval next = queue.poll();
            if (next == null) {
                workerCount.decrementAndGet();
            }
            return next;
        }
    }

    /** Builds a network from the broadcast configuration and initializes it with the given parameters. */
    private static Model createModel(boolean isCG, Broadcast<String> json, DeviceLocalNDArray deviceLocalParams) {
        String confJson = json.getValue();
        if (isCG) {
            ComputationGraph cg = new ComputationGraph(ComputationGraphConfiguration.fromJson(confJson));
            cg.init(deviceLocalParams.get(), false);
            return cg;
        }
        MultiLayerNetwork net = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(confJson));
        net.init(deviceLocalParams.get(), false);
        return net;
    }

    /**
     * Runs one evaluation. Any failure is recorded on {@code future}, and the future is always signalled as done.
     */
    private static void evalAndSignal(Model m, IEvaluation[] evals, Iterator<DataSet> ds, Iterator<MultiDataSet> mds,
                                      int evalBatchSize, EvaluationFuture future) {
        try {
            doEval(m, evals, ds, mds, evalBatchSize);
        } catch (Throwable t) {
            future.setException(t);
        } finally {
            future.getSemaphore().release(1);
        }
    }

    /**
     * Evaluates the model on {@code ds} if it is non-null, otherwise on {@code mds}. Results are accumulated
     * into {@code e}.
     */
    private static void doEval(Model m, IEvaluation[] e, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int evalBatchSize) {
        if (m instanceof MultiLayerNetwork) {
            MultiLayerNetwork mln = (MultiLayerNetwork) m;
            if (ds != null) {
                mln.doEvaluation((DataSetIterator) new IteratorDataSetIterator(ds, evalBatchSize), e);
            } else {
                mln.doEvaluation((MultiDataSetIterator) new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
            }
        } else {
            ComputationGraph cg = (ComputationGraph) m;
            if (ds != null) {
                cg.doEvaluation((DataSetIterator) new IteratorDataSetIterator(ds, evalBatchSize), e);
            } else {
                cg.doEvaluation((MultiDataSetIterator) new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
            }
        }
    }

    /**
     * Future completed by releasing a permit on its semaphore. The result and exception may be written by a
     * different thread than the one calling {@link #get()}, so both fields are volatile.
     */
    private static class EvaluationFuture
    implements Future<IEvaluation[]> {
        private Semaphore semaphore = new Semaphore(0);
        private volatile IEvaluation[] result;
        private volatile Throwable exception;

        private EvaluationFuture() {
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            throw new UnsupportedOperationException("Not supported");
        }

        @Override
        public boolean isCancelled() {
            return false;
        }

        @Override
        public boolean isDone() {
            return this.semaphore.availablePermits() > 0;
        }

        @Override
        public IEvaluation[] get() throws InterruptedException, ExecutionException {
            if (this.result == null && this.exception == null) {
                this.semaphore.acquire();
                // Return the permit so that isDone() stays true and repeated get() calls do not block forever
                this.semaphore.release();
            }
            return report();
        }

        @Override
        public IEvaluation[] get(long timeout, @NonNull TimeUnit unit)
                throws InterruptedException, ExecutionException, TimeoutException {
            if (unit == null) {
                throw new NullPointerException("unit is marked @NonNull but is null");
            }
            if (this.result == null && this.exception == null) {
                if (!this.semaphore.tryAcquire(timeout, unit)) {
                    throw new TimeoutException("Evaluation did not complete within " + timeout + " " + unit);
                }
                this.semaphore.release();
            }
            return report();
        }

        /** Throws the recorded failure wrapped in an ExecutionException, or returns the result (possibly null). */
        private IEvaluation[] report() throws ExecutionException {
            Throwable t = this.exception;
            if (t != null) {
                throw new ExecutionException(t);
            }
            return this.result;
        }

        public void setSemaphore(Semaphore semaphore) {
            this.semaphore = semaphore;
        }

        public void setResult(IEvaluation[] result) {
            this.result = result;
        }

        public void setException(Throwable exception) {
            this.exception = exception;
        }

        public Semaphore getSemaphore() {
            return this.semaphore;
        }

        public IEvaluation[] getResult() {
            return this.result;
        }

        public Throwable getException() {
            return this.exception;
        }
    }

    /**
     * A queued submission: the submitter's data iterators, its evaluation instances, and the future it waits on.
     * Workers evaluate the iterators into their own evaluation instances; {@code evaluations} is carried along
     * but not used for processing in this class.
     */
    private static class Eval {
        private Iterator<DataSet> ds;
        private Iterator<MultiDataSet> mds;
        private IEvaluation[] evaluations;
        private EvaluationFuture future;

        public Eval(Iterator<DataSet> ds, Iterator<MultiDataSet> mds, IEvaluation[] evaluations, EvaluationFuture future) {
            this.ds = ds;
            this.mds = mds;
            this.evaluations = evaluations;
            this.future = future;
        }

        public Iterator<DataSet> getDs() {
            return this.ds;
        }

        public Iterator<MultiDataSet> getMds() {
            return this.mds;
        }

        public IEvaluation[] getEvaluations() {
            return this.evaluations;
        }

        public EvaluationFuture getFuture() {
            return this.future;
        }

        public void setDs(Iterator<DataSet> ds) {
            this.ds = ds;
        }

        public void setMds(Iterator<MultiDataSet> mds) {
            this.mds = mds;
        }

        public void setEvaluations(IEvaluation[] evaluations) {
            this.evaluations = evaluations;
        }

        public void setFuture(EvaluationFuture future) {
            this.future = future;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Eval)) {
                return false;
            }
            Eval other = (Eval) o;
            if (!other.canEqual(this)) {
                return false;
            }
            Iterator<DataSet> thisDs = this.getDs();
            Iterator<DataSet> otherDs = other.getDs();
            if (thisDs == null ? otherDs != null : !thisDs.equals(otherDs)) {
                return false;
            }
            Iterator<MultiDataSet> thisMds = this.getMds();
            Iterator<MultiDataSet> otherMds = other.getMds();
            if (thisMds == null ? otherMds != null : !thisMds.equals(otherMds)) {
                return false;
            }
            if (!Arrays.deepEquals(this.getEvaluations(), other.getEvaluations())) {
                return false;
            }
            EvaluationFuture thisFuture = this.getFuture();
            EvaluationFuture otherFuture = other.getFuture();
            return thisFuture == null ? otherFuture == null : thisFuture.equals(otherFuture);
        }

        protected boolean canEqual(Object other) {
            return other instanceof Eval;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            Iterator<DataSet> dsValue = this.getDs();
            result = result * prime + (dsValue == null ? 43 : dsValue.hashCode());
            Iterator<MultiDataSet> mdsValue = this.getMds();
            result = result * prime + (mdsValue == null ? 43 : mdsValue.hashCode());
            result = result * prime + Arrays.deepHashCode(this.getEvaluations());
            EvaluationFuture futureValue = this.getFuture();
            result = result * prime + (futureValue == null ? 43 : futureValue.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "EvaluationRunner.Eval(ds=" + this.getDs() + ", mds=" + this.getMds() + ", evaluations=" + Arrays.deepToString(this.getEvaluations()) + ", future=" + this.getFuture() + ")";
        }
    }
}

