/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.datavec;

import java.io.Serializable;
import java.util.List;
import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.WritableConverterException;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

public class DataVecDataSetFunction
implements Function<List<Writable>, org.nd4j.linalg.dataset.DataSet>,
Serializable {
    private final int labelIndex;
    private final int labelIndexTo;
    private final int numPossibleLabels;
    private final boolean regression;
    private final DataSetPreProcessor preProcessor;
    private final WritableConverter converter;
    protected int batchSize = -1;

    public DataVecDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression) {
        this(labelIndex, numPossibleLabels, regression, null, null);
    }

    public DataVecDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression, DataSetPreProcessor preProcessor, WritableConverter converter) {
        this(labelIndex, labelIndex, numPossibleLabels, regression, preProcessor, converter);
    }

    public DataVecDataSetFunction(int labelIndexFrom, int labelIndexTo, int numPossibleLabels, boolean regression, DataSetPreProcessor preProcessor, WritableConverter converter) {
        this.labelIndex = labelIndexFrom;
        this.labelIndexTo = labelIndexTo;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.preProcessor = preProcessor;
        this.converter = converter;
    }

    public org.nd4j.linalg.dataset.DataSet call(List<Writable> currList) throws Exception {
        int labelIndex = this.labelIndex;
        if (this.numPossibleLabels >= 1 && labelIndex < 0) {
            labelIndex = currList.size() - 1;
        }
        INDArray label = null;
        INDArray featureVector = null;
        int featureCount = 0;
        int labelCount = 0;
        if (currList.size() == 2 && currList.get(1) instanceof NDArrayWritable && currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) {
            NDArrayWritable writable = (NDArrayWritable)currList.get(0);
            org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(writable.get(), writable.get());
            if (this.preProcessor != null) {
                this.preProcessor.preProcess((DataSet)ds);
            }
            return ds;
        }
        if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
            label = !this.regression ? FeatureUtil.toOutcomeVector((long)((int)Double.parseDouble(currList.get(1).toString())), (long)this.numPossibleLabels) : Nd4j.scalar((double)Double.parseDouble(currList.get(1).toString())).reshape(1L, 1L);
            NDArrayWritable ndArrayWritable = (NDArrayWritable)currList.get(0);
            featureVector = ndArrayWritable.get();
            org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(featureVector, label);
            if (this.preProcessor != null) {
                this.preProcessor.preProcess((DataSet)ds);
            }
            return ds;
        }
        for (int j = 0; j < currList.size(); ++j) {
            Writable current = currList.get(j);
            if (!(current instanceof NDArrayWritable) && current.toString().isEmpty()) continue;
            if (labelIndex >= 0 && j >= labelIndex && j <= this.labelIndexTo) {
                if (this.converter != null) {
                    try {
                        current = this.converter.convert(current);
                    }
                    catch (WritableConverterException e) {
                        e.printStackTrace();
                    }
                }
                if (this.regression) {
                    if (label == null) {
                        label = Nd4j.zeros((long)1L, (long)(this.labelIndexTo - labelIndex + 1));
                    }
                    label.putScalar(0L, (long)labelCount++, current.toDouble());
                    continue;
                }
                if (this.numPossibleLabels < 1) {
                    throw new IllegalStateException("Number of possible labels invalid, must be >= 1 for classification");
                }
                int curr = current.toInt();
                if (curr >= this.numPossibleLabels) {
                    throw new IllegalStateException("Invalid index: got index " + curr + " but numPossibleLabels is " + this.numPossibleLabels + " (must be 0 <= idx < numPossibleLabels");
                }
                label = FeatureUtil.toOutcomeVector((long)curr, (long)this.numPossibleLabels);
                continue;
            }
            try {
                double value = current.toDouble();
                if (featureVector == null) {
                    if (this.regression && labelIndex >= 0) {
                        int nLabels = this.labelIndexTo - labelIndex + 1;
                        featureVector = Nd4j.create((int)1, (int)(currList.size() - nLabels));
                    } else {
                        featureVector = Nd4j.create((int)1, (int)(labelIndex >= 0 ? currList.size() - 1 : currList.size()));
                    }
                }
                featureVector.putScalar((long)featureCount++, value);
                continue;
            }
            catch (UnsupportedOperationException e) {
                if (current instanceof NDArrayWritable) {
                    Preconditions.checkState((featureVector == null ? 1 : 0) != 0, (String)"Already got an array");
                    featureVector = ((NDArrayWritable)current).get();
                    continue;
                }
                throw e;
            }
        }
        org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(featureVector, labelIndex >= 0 ? label : featureVector);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess((DataSet)ds);
        }
        return ds;
    }
}

