/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.graph.scoring;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

public class ScoreExamplesWithKeyFunction<K>
implements PairFlatMapFunction<Iterator<Tuple2<K, org.nd4j.linalg.dataset.api.MultiDataSet>>, K, Double> {
    private static final Logger log = LoggerFactory.getLogger(ScoreExamplesWithKeyFunction.class);
    private final Broadcast<INDArray> params;
    private final Broadcast<String> jsonConfig;
    private final boolean addRegularization;
    private final int batchSize;

    public ScoreExamplesWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, boolean addRegularizationTerms, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.addRegularization = addRegularizationTerms;
        this.batchSize = batchSize;
    }

    public Iterator<Tuple2<K, Double>> call(Iterator<Tuple2<K, org.nd4j.linalg.dataset.api.MultiDataSet>> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyIterator();
        }
        ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson((String)((String)this.jsonConfig.getValue())));
        network.init();
        INDArray val = ((INDArray)this.params.value()).unsafeDuplication();
        if (val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
        }
        network.setParams(val);
        ArrayList<Tuple2> ret = new ArrayList<Tuple2>();
        ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet> collect = new ArrayList<org.nd4j.linalg.dataset.api.MultiDataSet>(this.batchSize);
        ArrayList<Object> collectKey = new ArrayList<Object>(this.batchSize);
        int totalCount = 0;
        while (iterator.hasNext()) {
            collect.clear();
            collectKey.clear();
            int nExamples = 0;
            while (iterator.hasNext() && nExamples < this.batchSize) {
                Tuple2<K, org.nd4j.linalg.dataset.api.MultiDataSet> t2 = iterator.next();
                org.nd4j.linalg.dataset.api.MultiDataSet ds = (org.nd4j.linalg.dataset.api.MultiDataSet)t2._2();
                long n = ds.getFeatures(0).size(0);
                if (n != 1L) {
                    throw new IllegalStateException("Cannot score examples with one key per data set if data set contains more than 1 example (numExamples: " + n + ")");
                }
                collect.add(ds);
                collectKey.add(t2._1());
                nExamples = (int)((long)nExamples + n);
            }
            totalCount += nExamples;
            MultiDataSet data = MultiDataSet.merge(collect);
            INDArray scores = network.scoreExamples((org.nd4j.linalg.dataset.api.MultiDataSet)data, this.addRegularization);
            double[] doubleScores = scores.data().asDouble();
            for (int i = 0; i < doubleScores.length; ++i) {
                ret.add(new Tuple2(collectKey.get(i), (Object)doubleScores[i]));
            }
        }
        Nd4j.getExecutioner().commit();
        if (log.isDebugEnabled()) {
            log.debug("Scored {} examples ", (Object)totalCount);
        }
        return ret.iterator();
    }
}

