/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.common.repartition;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.spark.Partitioner;
import org.nd4j.shade.guava.base.Preconditions;
import scala.Tuple2;

public class HashingBalancedPartitioner
extends Partitioner {
    private final int numClasses;
    private final int numPartitions;
    private List<List<Double>> partitionWeightsByClass;
    private List<List<Double>> jumpTable;
    private Random r;

    public HashingBalancedPartitioner(List<List<Double>> partitionWeightsByClass) {
        List pw = (List)Preconditions.checkNotNull(partitionWeightsByClass);
        Preconditions.checkArgument((!pw.isEmpty() ? 1 : 0) != 0, (Object)"Partition weights are required");
        Preconditions.checkArgument((pw.size() >= 1 ? 1 : 0) != 0, (Object)"There should be at least one element class");
        Preconditions.checkArgument((!((List)Preconditions.checkNotNull(pw.get(0))).isEmpty() ? 1 : 0) != 0, (Object)"At least one partition is required");
        this.numClasses = pw.size();
        this.numPartitions = ((List)pw.get(0)).size();
        for (int i = 1; i < pw.size(); ++i) {
            Preconditions.checkArgument((((List)Preconditions.checkNotNull(pw.get(i))).size() == this.numPartitions ? 1 : 0) != 0, (Object)"Non-consistent partition weight specification");
        }
        this.partitionWeightsByClass = partitionWeightsByClass;
        ArrayList<List<Double>> jumpsByClass = new ArrayList<List<Double>>();
        for (int j = 0; j < this.numClasses; ++j) {
            Double totalImbalance = 0.0;
            for (int i = 0; i < this.numPartitions; ++i) {
                totalImbalance = totalImbalance + (partitionWeightsByClass.get(j).get(i) >= 0.0 ? Math.max(1.0 - partitionWeightsByClass.get(j).get(i), 0.0) : 0.0);
            }
            Double sumProb = 0.0;
            ArrayList<Double> cumulProbsThisClass = new ArrayList<Double>();
            for (int i = 0; i < this.numPartitions; ++i) {
                if (partitionWeightsByClass.get(j).get(i) >= 0.0 && (totalImbalance > 0.0 || sumProb >= 1.0)) {
                    Double thisPartitionRelProb = Math.max(1.0 - partitionWeightsByClass.get(j).get(i), 0.0) / totalImbalance;
                    if (thisPartitionRelProb > 0.0) {
                        sumProb = sumProb + thisPartitionRelProb;
                        cumulProbsThisClass.add(sumProb);
                        continue;
                    }
                    cumulProbsThisClass.add(0.0);
                    continue;
                }
                cumulProbsThisClass.add(0.0);
            }
            jumpsByClass.add(cumulProbsThisClass);
        }
        this.jumpTable = jumpsByClass;
    }

    public int numPartitions() {
        List<Double> list = this.partitionWeightsByClass.get(0);
        int count = 0;
        for (Double d : list) {
            if (!(d >= 0.0)) continue;
            ++count;
        }
        return count;
    }

    public int getPartition(Object key) {
        Preconditions.checkArgument((boolean)(key instanceof Tuple2), (Object)"The key should be in the form: Tuple2(SparkUID, class) ...");
        Tuple2 uidNclass = (Tuple2)key;
        Long uid = (Long)uidNclass._1();
        Integer partitionId = (int)(uid % (long)this.numPartitions);
        Integer elementClass = (Integer)uidNclass._2();
        Double jumpProbability = Math.max(1.0 - 1.0 / this.partitionWeightsByClass.get(elementClass).get(partitionId), 0.0);
        LinearCongruentialGenerator rand = new LinearCongruentialGenerator(uid);
        Double thisJumps = rand.nextDouble();
        Integer thisPartition = partitionId;
        if (thisJumps < jumpProbability) {
            List<Double> jumpsTo = this.jumpTable.get(elementClass);
            Double destination = rand.nextDouble();
            Integer probe = 0;
            while (jumpsTo.get(probe) < destination) {
                Integer n = probe;
                Integer n2 = probe = Integer.valueOf(probe + 1);
            }
            thisPartition = probe;
        }
        return thisPartition;
    }

    static final class LinearCongruentialGenerator {
        private long state;

        public LinearCongruentialGenerator(long seed) {
            this.state = seed;
        }

        public double nextDouble() {
            this.state = 2862933555777941757L * this.state + 1L;
            return (double)((int)(this.state >>> 33) + 1) / 2.147483648E9;
        }
    }
}

