/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.datavec;

import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;
import scala.Tuple2;

public class DataVecSequencePairDataSetFunction
implements Function<Tuple2<List<List<Writable>>, List<List<Writable>>>, org.nd4j.linalg.dataset.DataSet>,
Serializable {
    private final boolean regression;
    private final int numPossibleLabels;
    private final AlignmentMode alignmentMode;
    private final DataSetPreProcessor preProcessor;
    private final WritableConverter converter;

    public DataVecSequencePairDataSetFunction() {
        this(-1, true);
    }

    public DataVecSequencePairDataSetFunction(int numPossibleLabels, boolean regression) {
        this(numPossibleLabels, regression, AlignmentMode.EQUAL_LENGTH);
    }

    public DataVecSequencePairDataSetFunction(int numPossibleLabels, boolean regression, AlignmentMode alignmentMode) {
        this(numPossibleLabels, regression, alignmentMode, null, null);
    }

    public DataVecSequencePairDataSetFunction(int numPossibleLabels, boolean regression, AlignmentMode alignmentMode, DataSetPreProcessor preProcessor, WritableConverter converter) {
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.alignmentMode = alignmentMode;
        this.preProcessor = preProcessor;
        this.converter = converter;
    }

    public org.nd4j.linalg.dataset.DataSet call(Tuple2<List<List<Writable>>, List<List<Writable>>> input) throws Exception {
        INDArray newInput;
        INDArray newOutput;
        org.nd4j.linalg.dataset.DataSet ds;
        Writable current;
        int f;
        Iterator timeStepIter;
        List step;
        List featuresSeq = (List)input._1();
        List labelsSeq = (List)input._2();
        int featuresLength = featuresSeq.size();
        int labelsLength = labelsSeq.size();
        Iterator fIter = featuresSeq.iterator();
        Iterator lIter = labelsSeq.iterator();
        INDArray inputArr = null;
        INDArray outputArr = null;
        int[] idx = new int[3];
        int i = 0;
        while (fIter.hasNext()) {
            step = (List)fIter.next();
            if (i == 0) {
                int[] inShape = new int[]{1, step.size(), featuresLength};
                inputArr = Nd4j.create((int[])inShape);
            }
            timeStepIter = step.iterator();
            f = 0;
            idx[1] = 0;
            while (timeStepIter.hasNext()) {
                current = (Writable)timeStepIter.next();
                if (this.converter != null) {
                    current = this.converter.convert(current);
                }
                try {
                    inputArr.putScalar(idx, current.toDouble());
                }
                catch (UnsupportedOperationException e) {
                    if (current instanceof NDArrayWritable) {
                        inputArr.get(new INDArrayIndex[]{NDArrayIndex.point((long)idx[0]), NDArrayIndex.all(), NDArrayIndex.point((long)idx[2])}).putRow(0L, ((NDArrayWritable)current).get());
                    }
                    throw e;
                }
                idx[1] = ++f;
            }
            idx[2] = ++i;
        }
        idx = new int[3];
        i = 0;
        while (lIter.hasNext()) {
            step = (List)lIter.next();
            if (i == 0) {
                int[] outShape = new int[]{1, this.regression ? step.size() : this.numPossibleLabels, labelsLength};
                outputArr = Nd4j.create((int[])outShape);
            }
            timeStepIter = step.iterator();
            f = 0;
            idx[1] = 0;
            if (this.regression) {
                while (timeStepIter.hasNext()) {
                    current = (Writable)timeStepIter.next();
                    if (this.converter != null) {
                        current = this.converter.convert(current);
                    }
                    outputArr.putScalar(idx, current.toDouble());
                    idx[1] = ++f;
                }
            } else {
                Writable value = (Writable)timeStepIter.next();
                int labelClassIdx = value.toInt();
                INDArray line = FeatureUtil.toOutcomeVector((long)labelClassIdx, (long)this.numPossibleLabels);
                outputArr.tensorAlongDimension((long)i, new int[]{1}).assign(line);
            }
            idx[2] = ++i;
        }
        if (this.alignmentMode == AlignmentMode.EQUAL_LENGTH || featuresLength == labelsLength) {
            ds = new org.nd4j.linalg.dataset.DataSet(inputArr, outputArr);
        } else if (this.alignmentMode == AlignmentMode.ALIGN_END) {
            if (featuresLength > labelsLength) {
                newOutput = Nd4j.create((long[])new long[]{1L, outputArr.size(1), featuresLength});
                newOutput.get(new INDArrayIndex[]{NDArrayIndex.point((long)0L), NDArrayIndex.all(), NDArrayIndex.interval((int)(featuresLength - labelsLength), (int)featuresLength)}).assign(outputArr);
                INDArray outputMask = Nd4j.create((int)1, (int)featuresLength);
                for (int j = featuresLength - labelsLength; j < featuresLength; ++j) {
                    outputMask.putScalar((long)j, 1.0);
                }
                ds = new org.nd4j.linalg.dataset.DataSet(inputArr, newOutput, Nd4j.ones((long[])outputMask.shape()), outputMask);
            } else {
                newInput = Nd4j.create((long[])new long[]{1L, inputArr.size(1), labelsLength});
                newInput.get(new INDArrayIndex[]{NDArrayIndex.point((long)0L), NDArrayIndex.all(), NDArrayIndex.interval((int)(labelsLength - featuresLength), (int)labelsLength)}).assign(inputArr);
                INDArray inputMask = Nd4j.create((int)1, (int)labelsLength);
                for (int j = labelsLength - featuresLength; j < labelsLength; ++j) {
                    inputMask.putScalar((long)j, 1.0);
                }
                ds = new org.nd4j.linalg.dataset.DataSet(newInput, outputArr, inputMask, Nd4j.ones((long[])inputMask.shape()));
            }
        } else if (this.alignmentMode == AlignmentMode.ALIGN_START) {
            if (featuresLength > labelsLength) {
                newOutput = Nd4j.create((long[])new long[]{1L, outputArr.size(1), featuresLength});
                newOutput.get(new INDArrayIndex[]{NDArrayIndex.point((long)0L), NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)labelsLength)}).assign(outputArr);
                INDArray outputMask = Nd4j.create((int)1, (int)featuresLength);
                for (int j = 0; j < labelsLength; ++j) {
                    outputMask.putScalar((long)j, 1.0);
                }
                ds = new org.nd4j.linalg.dataset.DataSet(inputArr, newOutput, Nd4j.ones((long[])outputMask.shape()), outputMask);
            } else {
                newInput = Nd4j.create((long[])new long[]{1L, inputArr.size(1), labelsLength});
                newInput.get(new INDArrayIndex[]{NDArrayIndex.point((long)0L), NDArrayIndex.all(), NDArrayIndex.interval((int)0, (int)featuresLength)}).assign(inputArr);
                INDArray inputMask = Nd4j.create((int)1, (int)labelsLength);
                for (int j = 0; j < featuresLength; ++j) {
                    inputMask.putScalar((long)j, 1.0);
                }
                ds = new org.nd4j.linalg.dataset.DataSet(newInput, outputArr, inputMask, Nd4j.ones((long[])inputMask.shape()));
            }
        } else {
            throw new UnsupportedOperationException("Invalid alignment mode: " + (Object)((Object)this.alignmentMode));
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)ds);
        }
        return ds;
    }

    public static enum AlignmentMode {
        EQUAL_LENGTH,
        ALIGN_START,
        ALIGN_END;

    }
}

