/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.debugging;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * A {@link BaseListener} that measures how long individual SameDiff ops take to execute.
 *
 * <p>The runtime of an op is the time between {@link #preOpExecution} and {@link #opExecution},
 * measured with the monotonic {@link System#nanoTime()} clock and reported in whole milliseconds.
 *
 * <ul>
 *   <li>{@link Mode#SINGLE_ITER_PRINT}: while a matching {@link Operation} is running (from
 *       {@link #operationStart} until the first matching {@link #operationEnd}), each op whose
 *       runtime is strictly greater than {@code minRuntime} ms is printed to {@code System.out}
 *       together with its input/output shapes and arguments. After the first matching operation
 *       ends, nothing more is printed.</li>
 *   <li>{@link Mode#AGGREGATE}: the runtime of each op execution reported to this listener is
 *       appended to a per-op-name {@link OpExec} record, available via
 *       {@link #getAggregateModeMap()}.</li>
 * </ul>
 *
 * <p>No synchronization is performed; timing state is kept in plain instance fields.
 */
public class OpBenchmarkListener
extends BaseListener {
    /** Divisor converting a {@link System#nanoTime()} difference to whole milliseconds. */
    private static final long NANOS_PER_MILLI = 1_000_000L;

    /** Operation to benchmark, or {@code null} to match every operation. */
    private final Operation operation;
    /** Reporting mode: print while running, or aggregate per-op runtimes. */
    private final Mode mode;
    /** SINGLE_ITER_PRINT only: ops must take strictly longer than this many ms to be printed. */
    private final long minRuntime;
    /** AGGREGATE only: records keyed by op name, created lazily, ordered by first execution. */
    private Map<String, OpExec> aggregateModeMap;
    /** {@link System#nanoTime()} captured in preOpExecution for the op currently being timed. */
    private long start;
    /** True while a matching operation is running and printing has not finished yet. */
    private boolean printActive;
    /** Set once the first matching operation ends; later start/end events are ignored. */
    private boolean printDone;

    /**
     * Creates a listener that prints/records ops regardless of runtime (minimum runtime 0 ms).
     *
     * @param operation operation to benchmark, or {@code null} for all operations
     * @param mode      reporting mode; must not be {@code null}
     * @throws NullPointerException if {@code mode} is {@code null}
     */
    public OpBenchmarkListener(Operation operation, @NonNull Mode mode) {
        // The delegated constructor performs the null check on mode.
        this(operation, mode, 0L);
    }

    /**
     * Creates a listener.
     *
     * @param operation  operation to benchmark, or {@code null} for all operations
     * @param mode       reporting mode; must not be {@code null}
     * @param minRuntime SINGLE_ITER_PRINT only: ops taking this many ms or less are not printed
     * @throws NullPointerException if {@code mode} is {@code null}
     */
    public OpBenchmarkListener(Operation operation, @NonNull Mode mode, long minRuntime) {
        if (mode == null) {
            throw new NullPointerException("mode is marked @NonNull but is null");
        }
        this.operation = operation;
        this.mode = mode;
        this.minRuntime = minRuntime;
    }

    /** Returns true when no operation filter is set or {@code op} equals the configured one. */
    private boolean matches(Operation op) {
        return this.operation == null || this.operation == op;
    }

    @Override
    public boolean isActive(Operation operation) {
        return this.matches(operation);
    }

    @Override
    public void operationStart(SameDiff sd, Operation op) {
        if (!this.printDone && this.matches(op)) {
            this.printActive = true;
        }
    }

    @Override
    public void operationEnd(SameDiff sd, Operation op) {
        if (!this.printDone && this.matches(op)) {
            this.printActive = false;
            this.printDone = true;
        }
    }

    @Override
    public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
        // nanoTime is monotonic, unlike currentTimeMillis, so intervals cannot go negative.
        this.start = System.nanoTime();
    }

    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
        long elapsedMs = (System.nanoTime() - this.start) / NANOS_PER_MILLI;
        if (this.mode == Mode.SINGLE_ITER_PRINT) {
            if (this.printActive && elapsedMs > this.minRuntime) {
                System.out.println(this.getOpString(op, elapsedMs));
            }
        } else if (this.mode == Mode.AGGREGATE) {
            if (this.aggregateModeMap == null) {
                this.aggregateModeMap = new LinkedHashMap<String, OpExec>();
            }
            OpExec exec = this.aggregateModeMap.get(op.getName());
            if (exec == null) {
                // Describe the op (shapes/args) only once, from its first execution.
                exec = new OpExec(op.getName(), op.getOp().opName(), op.getOp().getClass(),
                        new ArrayList<Long>(), this.getOpString(op, null));
                this.aggregateModeMap.put(op.getName(), exec);
            }
            exec.getRuntimeMs().add(elapsedMs);
        }
    }

    /**
     * Builds a multi-line description of an op: a header line, then one line per input/output
     * array shape and per non-empty argument array.
     *
     * @param op        the op to describe
     * @param elapsedMs runtime to include in the header, or {@code null} to omit it
     * @return the description, each line terminated by {@code '\n'}
     */
    private String getOpString(SameDiffOp op, Long elapsedMs) {
        Object fn = op.getOp();
        StringBuilder sb = new StringBuilder();
        sb.append(op.getName()).append(" - ").append(fn.getClass().getSimpleName())
                .append("(").append(op.getOp().opName()).append(")");
        if (elapsedMs != null) {
            sb.append(" - ").append(elapsedMs).append(" ms");
        }
        sb.append("\n");
        if (fn instanceof DynamicCustomOp) {
            DynamicCustomOp dco = (DynamicCustomOp) fn;
            int x = 0;
            for (INDArray i : dco.inputArguments()) {
                sb.append("  in ").append(x++).append(": ").append(shapeInfo(i)).append("\n");
            }
            x = 0;
            for (INDArray o : dco.outputArguments()) {
                sb.append("  out ").append(x++).append(": ").append(shapeInfo(o)).append("\n");
            }
            long[] iargs = dco.iArgs();
            boolean[] bargs = dco.bArgs();
            double[] targs = dco.tArgs();
            if (iargs != null && iargs.length > 0) {
                sb.append("  iargs: ").append(Arrays.toString(iargs)).append("\n");
            }
            if (bargs != null && bargs.length > 0) {
                sb.append("  bargs: ").append(Arrays.toString(bargs)).append("\n");
            }
            if (targs != null && targs.length > 0) {
                sb.append("  targs: ").append(Arrays.toString(targs)).append("\n");
            }
        } else if (fn instanceof Op) {
            // Guarded: ops that are neither DynamicCustomOp nor Op previously caused a ClassCastException.
            Op o = (Op) fn;
            if (o.x() != null) {
                sb.append("  x: ").append(o.x().shapeInfoToString()).append("\n");
            }
            if (o.y() != null) {
                sb.append("  y: ").append(o.y().shapeInfoToString()).append("\n");
            }
            if (o.z() != null) {
                sb.append("  z: ").append(o.z().shapeInfoToString()).append("\n");
            }
        }
        return sb.toString();
    }

    /** Shape info of {@code arr}, or "null" when the array is absent. */
    private static String shapeInfo(INDArray arr) {
        return arr == null ? "null" : arr.shapeInfoToString();
    }

    public Operation getOperation() {
        return this.operation;
    }

    public Mode getMode() {
        return this.mode;
    }

    public long getMinRuntime() {
        return this.minRuntime;
    }

    /**
     * Returns the live, insertion-ordered map of per-op records collected in AGGREGATE mode.
     *
     * @return the map, or {@code null} if no op has been recorded yet
     */
    public Map<String, OpExec> getAggregateModeMap() {
        return this.aggregateModeMap;
    }

    /** Returns true once the first matching operation has ended (SINGLE_ITER_PRINT is finished). */
    public boolean isPrintDone() {
        return this.printDone;
    }

    /**
     * Runtime record for one op (keyed by its own name) collected in AGGREGATE mode.
     */
    public static class OpExec {
        private final String opOwnName;
        private final String opName;
        private final Class<?> opClass;
        private List<Long> runtimeMs;
        private String firstIter;

        public OpExec(String opOwnName, String opName, Class<?> opClass, List<Long> runtimeMs, String firstIter) {
            this.opOwnName = opOwnName;
            this.opName = opName;
            this.opClass = opClass;
            this.runtimeMs = runtimeMs;
            this.firstIter = firstIter;
        }

        @Override
        public String toString() {
            DecimalFormat df = new DecimalFormat("0.000");
            return this.opOwnName + " - op class: " + this.opClass.getSimpleName() + " (op name: " + this.opName + ")\ncount: " + this.runtimeMs.size() + ", mean: " + df.format(this.avgMs()) + "ms, std: " + df.format(this.stdMs()) + "ms, min: " + this.minMs() + "ms, max: " + this.maxMs() + "ms\n" + this.firstIter;
        }

        /** Mean runtime in ms, or {@link Double#NaN} when no runtimes are recorded. */
        public double avgMs() {
            return this.runtimeMs.stream().mapToLong(Long::longValue).average().orElse(Double.NaN);
        }

        /**
         * Sample (bias-corrected, divisor n - 1) standard deviation of the runtimes in ms.
         *
         * @return the standard deviation, or 0 when fewer than two runtimes are recorded
         */
        public double stdMs() {
            int n = this.runtimeMs.size();
            if (n < 2) {
                return 0.0;
            }
            double mean = this.avgMs();
            double sumSq = 0.0;
            for (long v : this.runtimeMs) {
                double d = v - mean;
                sumSq += d * d;
            }
            return Math.sqrt(sumSq / (n - 1));
        }

        /** Minimum runtime in ms, or 0 when no runtimes are recorded. */
        public long minMs() {
            return this.runtimeMs.stream().mapToLong(Long::longValue).min().orElse(0L);
        }

        /** Maximum runtime in ms, or 0 when no runtimes are recorded. */
        public long maxMs() {
            return this.runtimeMs.stream().mapToLong(Long::longValue).max().orElse(0L);
        }

        public String getOpOwnName() {
            return this.opOwnName;
        }

        public String getOpName() {
            return this.opName;
        }

        public Class<?> getOpClass() {
            return this.opClass;
        }

        /** Returns the live, mutable list of recorded runtimes in ms. */
        public List<Long> getRuntimeMs() {
            return this.runtimeMs;
        }

        /** Returns the description of the op as seen on its first recorded execution. */
        public String getFirstIter() {
            return this.firstIter;
        }

        public void setRuntimeMs(List<Long> runtimeMs) {
            this.runtimeMs = runtimeMs;
        }

        public void setFirstIter(String firstIter) {
            this.firstIter = firstIter;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof OpExec)) {
                return false;
            }
            OpExec other = (OpExec) o;
            return other.canEqual(this)
                    && Objects.equals(this.getOpOwnName(), other.getOpOwnName())
                    && Objects.equals(this.getOpName(), other.getOpName())
                    && Objects.equals(this.getOpClass(), other.getOpClass())
                    && Objects.equals(this.getRuntimeMs(), other.getRuntimeMs())
                    && Objects.equals(this.getFirstIter(), other.getFirstIter());
        }

        /** Symmetry hook for subclasses overriding equals (Lombok-style canEqual). */
        protected boolean canEqual(Object other) {
            return other instanceof OpExec;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + hashOrDefault(this.getOpOwnName());
            result = result * prime + hashOrDefault(this.getOpName());
            result = result * prime + hashOrDefault(this.getOpClass());
            result = result * prime + hashOrDefault(this.getRuntimeMs());
            result = result * prime + hashOrDefault(this.getFirstIter());
            return result;
        }

        /** Hash of {@code value}, or 43 for {@code null} (keeps the original hash scheme). */
        private static int hashOrDefault(Object value) {
            return value == null ? 43 : value.hashCode();
        }
    }

    /** Reporting mode for {@link OpBenchmarkListener}. */
    public static enum Mode {
        /** Print each sufficiently slow op while the first matching operation runs. */
        SINGLE_ITER_PRINT,
        /** Record every op runtime into per-op {@link OpExec} records. */
        AGGREGATE
    }
}

