/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.data;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.nd4j.linalg.dataset.DataSet;

/**
 * Spark partition function that groups consecutive {@link DataSet} objects into larger
 * DataSets of roughly {@code minibatchSize} examples each.
 *
 * <p>DataSets are taken in iteration order and added to the current batch until the total
 * number of examples (the sum of {@code getFeatures().size(0)}) reaches at least
 * {@code minibatchSize}, or the input is exhausted. Because whole DataSets are never split, a
 * batch may contain more than {@code minibatchSize} examples. A single input DataSet that
 * already reaches the threshold is passed through unchanged.
 */
public class BatchDataSetsFunction
implements FlatMapFunction<Iterator<DataSet>, DataSet> {
    private final int minibatchSize;

    /**
     * Batches the DataSets of one partition.
     *
     * @param iter the DataSets of the partition, consumed in order
     * @return an iterator over the batched DataSets; empty if {@code iter} is empty
     * @throws Exception as permitted by {@link FlatMapFunction#call}
     */
    public Iterator<DataSet> call(Iterator<DataSet> iter) throws Exception {
        List<DataSet> out = new ArrayList<>();
        while (iter.hasNext()) {
            List<DataSet> batch = new ArrayList<>();
            // size(0) returns long; accumulate as long so large partitions cannot overflow.
            long exampleCount = 0L;
            // do-while: always consume at least one DataSet, so every outer iteration makes
            // progress even when minibatchSize <= 0.
            do {
                DataSet ds = iter.next();
                exampleCount += ds.getFeatures().size(0);
                batch.add(ds);
            } while (exampleCount < this.minibatchSize && iter.hasNext());
            // A lone DataSet needs no merge; only combine when there is more than one.
            out.add(batch.size() == 1 ? batch.get(0) : DataSet.merge(batch));
        }
        return out.iterator();
    }

    /**
     * @param minibatchSize target minimum number of examples per output DataSet
     */
    public BatchDataSetsFunction(int minibatchSize) {
        this.minibatchSize = minibatchSize;
    }
}

