/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.iterator;

import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContextHelper;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
import org.nd4j.linalg.dataset.callbacks.DefaultCallback;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SparkADSI
extends AsyncDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(SparkADSI.class);
    protected TaskContext context;

    protected SparkADSI() {
    }

    public SparkADSI(DataSetIterator baseIterator) {
        this(baseIterator, 8);
    }

    public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue) {
        this(iterator, queueSize, queue, true);
    }

    public SparkADSI(DataSetIterator baseIterator, int queueSize) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize));
    }

    public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace);
    }

    public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, (DataSetCallback)new DefaultCallback(), deviceId);
    }

    public SparkADSI(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, DataSetCallback callback) {
        this(baseIterator, queueSize, new LinkedBlockingQueue<DataSet>(queueSize), useWorkspace, callback);
    }

    public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace) {
        this(iterator, queueSize, queue, useWorkspace, (DataSetCallback)new DefaultCallback());
    }

    public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    public SparkADSI(DataSetIterator iterator, int queueSize, BlockingQueue<DataSet> queue, boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        this();
        if (queueSize < 2) {
            queueSize = 2;
        }
        this.deviceId = deviceId;
        this.callback = callback;
        this.useWorkspace = useWorkspace;
        this.buffer = queue;
        this.prefetchSize = queueSize;
        this.backedIterator = iterator;
        this.workspaceId = "SADSI_ITER-" + UUID.randomUUID().toString();
        if (iterator.resetSupported()) {
            this.backedIterator.reset();
        }
        this.context = TaskContext.get();
        this.thread = new SparkPrefetchThread(this.buffer, iterator, this.terminator, null, Nd4j.getAffinityManager().getDeviceForCurrentThread());
        this.thread.setDaemon(true);
        this.thread.start();
    }

    protected void externalCall() {
        TaskContextHelper.setTaskContext(this.context);
    }

    public class SparkPrefetchThread
    extends AsyncDataSetIterator.AsyncPrefetchThread {
        protected SparkPrefetchThread(BlockingQueue<DataSet> queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) {
            super((AsyncDataSetIterator)SparkADSI.this, queue, iterator, terminator, workspace, deviceId);
        }
    }
}

