package org.deeplearning4j.parallelism.main;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import java.io.File;
import java.util.Objects;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter;
import org.deeplearning4j.core.util.ModelGuesser;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/main/ParallelWrapperMain.class */
/**
 * Command-line entry point that trains a serialized model with {@link ParallelWrapper}.
 *
 * <p>The model is loaded from {@code --modelPath} and trained on the iterator produced by the
 * factory class named in {@code --dataSetIteratorFactoryClazz}. If that option is absent,
 * {@code --multiDataSetIteratorFactoryClazz} is used instead. The trained model is then written
 * to {@code --modelOutputPath}. When {@code --uiUrl} is set, a stats listener is attached and
 * routed to a {@link RemoteUIStatsStorageRouter} pointing at that host:port.
 */
public class ParallelWrapperMain {
    private static final Logger log = LoggerFactory.getLogger(ParallelWrapperMain.class);

    /** Fully qualified name of the UI stats listener; it is loaded reflectively by name. */
    private static final String STATS_LISTENER_CLASS = "org.deeplearning4j.ui.model.stats.StatsListener";

    @Parameter(names = {"--modelPath"}, description = "Path to the model", arity = 1, required = true)
    private String modelPath = null;

    @Parameter(names = {"--workers"}, description = "Number of workers", arity = 1)
    private int workers = 2;

    @Parameter(names = {"--prefetchSize"}, description = "The number of datasets to prefetch", arity = 1)
    private int prefetchSize = 16;

    @Parameter(names = {"--averagingFrequency"}, description = "The frequency for averaging parameters", arity = 1)
    private int averagingFrequency = 1;

    @Parameter(names = {"--reportScore"}, description = "Whether to report the score after parameter averaging", arity = 1)
    private boolean reportScore = false;

    @Parameter(names = {"--averageUpdaters"}, description = "Whether to average updaters", arity = 1)
    private boolean averageUpdaters = true;

    // NOTE(review): parsed but never passed to the ParallelWrapper builder in run(), so it
    // currently has no effect on training. Kept so existing command lines still parse.
    @Parameter(names = {"--legacyAveraging"}, description = "Whether to use legacy averaging", arity = 1)
    private boolean legacyAveraging = true;

    @Parameter(names = {"--dataSetIteratorFactoryClazz"}, description = "The fully qualified class name of the DataSetIteratorProviderFactory to use.", arity = 1)
    private String dataSetIteratorFactoryClazz = null;

    @Parameter(names = {"--multiDataSetIteratorFactoryClazz"}, description = "The fully qualified class name of the MultiDataSetProviderFactory to use.", arity = 1)
    private String multiDataSetIteratorFactoryClazz = null;

    @Parameter(names = {"--modelOutputPath"}, description = "Path to write the trained model to", arity = 1, required = true)
    private String modelOutputPath = null;

    @Parameter(names = {"--uiUrl"}, description = "The host:port of the ui to use (optional)", arity = 1)
    private String uiUrl = null;

    /** Router created by run() when --uiUrl is set; shut down by stop(). */
    private RemoteUIStatsStorageRouter remoteUIRouter;

    /** Wrapper built by run(); closed by stop(). */
    private ParallelWrapper wrapper;

    public static void main(String[] strArr) throws Exception {
        new ParallelWrapperMain().runMain(strArr);
    }

    /**
     * Parses command-line arguments into this instance's fields, then calls {@link #run()}.
     *
     * <p>If parsing fails, the error message is printed to stderr, JCommander usage is printed,
     * and the JVM exits with status 1.
     *
     * @param strArr command-line arguments
     * @throws Exception if loading, training, or saving the model fails
     */
    public void runMain(String... strArr) throws Exception {
        JCommander jCommander = new JCommander(this);
        try {
            jCommander.parse(strArr);
        } catch (ParameterException e) {
            System.err.println(e.getMessage());
            jCommander.usage();
            // Brief pause before exiting, presumably so the usage output is flushed.
            try {
                Thread.sleep(500L);
            } catch (InterruptedException ignored) {
                // Exiting anyway; just preserve the interrupt status.
                Thread.currentThread().interrupt();
            }
            System.exit(1);
        }
        run();
    }

    /**
     * Loads the model, trains it with a {@link ParallelWrapper}, and writes it to
     * {@link #getModelOutputPath()}.
     *
     * <p>A DataSet iterator factory takes precedence over a MultiDataSet iterator factory when
     * both are configured.
     *
     * @throws IllegalStateException if neither iterator factory class is configured
     * @throws Exception if loading, training, or saving the model fails
     */
    public void run() throws Exception {
        // Validate first so that no model is loaded and no wrapper is built for a bad invocation.
        if (this.dataSetIteratorFactoryClazz == null && this.multiDataSetIteratorFactoryClazz == null) {
            throw new IllegalStateException("Please provide a datasetiterator or multi datasetiterator class");
        }

        Model model = ModelGuesser.loadModelGuess(this.modelPath);
        this.wrapper = new ParallelWrapper.Builder(model)
                .prefetchBuffer(this.prefetchSize)
                .workers(this.workers)
                .averagingFrequency(this.averagingFrequency)
                .averageUpdaters(this.averageUpdaters)
                .reportScoreAfterAveraging(this.reportScore)
                .build();

        if (this.dataSetIteratorFactoryClazz != null) {
            DataSetIterator iterator = ((DataSetIteratorProviderFactory)
                    DL4JClassLoading.createNewInstance(this.dataSetIteratorFactoryClazz)).create();
            attachUiListeners();
            this.wrapper.fit(iterator);
        } else {
            MultiDataSetIterator iterator = ((MultiDataSetProviderFactory)
                    DL4JClassLoading.createNewInstance(this.multiDataSetIteratorFactoryClazz)).create();
            attachUiListeners();
            this.wrapper.fit(iterator);
        }

        ModelSerializer.writeModel(model, new File(this.modelOutputPath), true);
    }

    /**
     * If {@code --uiUrl} is set, creates the remote UI router, stores it in
     * {@link #remoteUIRouter} so {@link #stop()} can shut it down, and registers a
     * reflectively-loaded stats listener on the wrapper routed to it.
     */
    private void attachUiListeners() {
        if (this.uiUrl == null) {
            return;
        }
        this.remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + this.uiUrl);
        // StatsListener is a TrainingListener, so that is the supertype to check against.
        // It is constructed with a null router; the router is supplied via setListeners below.
        TrainingListener statsListener = (TrainingListener) DL4JClassLoading.createNewInstance(
                STATS_LISTENER_CLASS,
                TrainingListener.class,
                new Class[]{StatsStorageRouter.class},
                new Object[]{null});
        this.wrapper.setListeners((StatsStorageRouter) this.remoteUIRouter, statsListener);
    }

    /**
     * Shuts down the remote UI router (if one was created) and closes the wrapper (if built).
     *
     * @throws RuntimeException wrapping any throwable raised while closing the wrapper
     */
    public void stop() {
        if (this.remoteUIRouter != null) {
            this.remoteUIRouter.shutdown();
        }
        if (this.wrapper != null) {
            try {
                this.wrapper.close();
            } catch (Throwable th) {
                log.warn("ParallelWrapperMain.close(): Exception encountered trying to close ParallelWrapper instance", th);
                throw new RuntimeException(th);
            }
        }
    }

    // ----- accessors -----

    public String getModelPath() {
        return this.modelPath;
    }

    public int getWorkers() {
        return this.workers;
    }

    public int getPrefetchSize() {
        return this.prefetchSize;
    }

    public int getAveragingFrequency() {
        return this.averagingFrequency;
    }

    public boolean isReportScore() {
        return this.reportScore;
    }

    public boolean isAverageUpdaters() {
        return this.averageUpdaters;
    }

    public boolean isLegacyAveraging() {
        return this.legacyAveraging;
    }

    public String getDataSetIteratorFactoryClazz() {
        return this.dataSetIteratorFactoryClazz;
    }

    public String getMultiDataSetIteratorFactoryClazz() {
        return this.multiDataSetIteratorFactoryClazz;
    }

    public String getModelOutputPath() {
        return this.modelOutputPath;
    }

    public String getUiUrl() {
        return this.uiUrl;
    }

    public RemoteUIStatsStorageRouter getRemoteUIRouter() {
        return this.remoteUIRouter;
    }

    public ParallelWrapper getWrapper() {
        return this.wrapper;
    }

    public void setModelPath(String str) {
        this.modelPath = str;
    }

    public void setWorkers(int i) {
        this.workers = i;
    }

    public void setPrefetchSize(int i) {
        this.prefetchSize = i;
    }

    public void setAveragingFrequency(int i) {
        this.averagingFrequency = i;
    }

    public void setReportScore(boolean z) {
        this.reportScore = z;
    }

    public void setAverageUpdaters(boolean z) {
        this.averageUpdaters = z;
    }

    public void setLegacyAveraging(boolean z) {
        this.legacyAveraging = z;
    }

    public void setDataSetIteratorFactoryClazz(String str) {
        this.dataSetIteratorFactoryClazz = str;
    }

    public void setMultiDataSetIteratorFactoryClazz(String str) {
        this.multiDataSetIteratorFactoryClazz = str;
    }

    public void setModelOutputPath(String str) {
        this.modelOutputPath = str;
    }

    public void setUiUrl(String str) {
        this.uiUrl = str;
    }

    public void setRemoteUIRouter(RemoteUIStatsStorageRouter remoteUIStatsStorageRouter) {
        this.remoteUIRouter = remoteUIStatsStorageRouter;
    }

    public void setWrapper(ParallelWrapper parallelWrapper) {
        this.wrapper = parallelWrapper;
    }

    // ----- value semantics over all fields (including runtime router/wrapper) -----

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ParallelWrapperMain)) {
            return false;
        }
        ParallelWrapperMain other = (ParallelWrapperMain) obj;
        return other.canEqual(this)
                && getWorkers() == other.getWorkers()
                && getPrefetchSize() == other.getPrefetchSize()
                && getAveragingFrequency() == other.getAveragingFrequency()
                && isReportScore() == other.isReportScore()
                && isAverageUpdaters() == other.isAverageUpdaters()
                && isLegacyAveraging() == other.isLegacyAveraging()
                && Objects.equals(getModelPath(), other.getModelPath())
                && Objects.equals(getDataSetIteratorFactoryClazz(), other.getDataSetIteratorFactoryClazz())
                && Objects.equals(getMultiDataSetIteratorFactoryClazz(), other.getMultiDataSetIteratorFactoryClazz())
                && Objects.equals(getModelOutputPath(), other.getModelOutputPath())
                && Objects.equals(getUiUrl(), other.getUiUrl())
                && Objects.equals(getRemoteUIRouter(), other.getRemoteUIRouter())
                && Objects.equals(getWrapper(), other.getWrapper());
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof ParallelWrapperMain;
    }

    /** Same algorithm (and therefore same values) as the previous Lombok-style hash. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + getWorkers();
        result = result * prime + getPrefetchSize();
        result = result * prime + getAveragingFrequency();
        result = result * prime + hashOf(isReportScore());
        result = result * prime + hashOf(isAverageUpdaters());
        result = result * prime + hashOf(isLegacyAveraging());
        result = result * prime + hashOf(getModelPath());
        result = result * prime + hashOf(getDataSetIteratorFactoryClazz());
        result = result * prime + hashOf(getMultiDataSetIteratorFactoryClazz());
        result = result * prime + hashOf(getModelOutputPath());
        result = result * prime + hashOf(getUiUrl());
        result = result * prime + hashOf(getRemoteUIRouter());
        result = result * prime + hashOf(getWrapper());
        return result;
    }

    /** Boolean hash contribution: 79 for true, 97 for false. */
    private static int hashOf(boolean value) {
        return value ? 79 : 97;
    }

    /** Object hash contribution: 43 for null, otherwise the object's own hash code. */
    private static int hashOf(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Override
    public String toString() {
        return "ParallelWrapperMain(modelPath=" + getModelPath() + ", workers=" + getWorkers() + ", prefetchSize=" + getPrefetchSize() + ", averagingFrequency=" + getAveragingFrequency() + ", reportScore=" + isReportScore() + ", averageUpdaters=" + isAverageUpdaters() + ", legacyAveraging=" + isLegacyAveraging() + ", dataSetIteratorFactoryClazz=" + getDataSetIteratorFactoryClazz() + ", multiDataSetIteratorFactoryClazz=" + getMultiDataSetIteratorFactoryClazz() + ", modelOutputPath=" + getModelOutputPath() + ", uiUrl=" + getUiUrl() + ", remoteUIRouter=" + getRemoteUIRouter() + ", wrapper=" + getWrapper() + ")";
    }
}
