package cc.mallet.fst;

import cc.mallet.fst.TransducerTrainer;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.InstanceList;
import cc.mallet.util.MalletLogger;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Trains a {@link CRF} by maximizing label likelihood with L-BFGS.
 *
 * <p>The objective is a {@link CRFOptimizableByBatchLabelLikelihood} (constructed with
 * {@code numThreads}) wrapped in a {@link ThreadedOptimizable}; the optimizer is a
 * {@link LimitedMemoryBFGS} over that wrapper. A Gaussian prior variance can be configured with
 * {@link #setGaussianPriorVariance(double)}.
 *
 * <p>The {@link ThreadedOptimizable} presumably owns worker threads (it exposes
 * {@code shutdown()}); call {@link #shutdown()} when training is finished so they are released.
 */
public class CRFTrainerByThreadedLabelLikelihood extends TransducerTrainer implements TransducerTrainer.ByOptimization {
    private static final Logger logger =
            MalletLogger.getLogger(CRFTrainerByThreadedLabelLikelihood.class.getName());

    /** Default value of {@link #getGaussianPriorVariance()}. */
    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;

    private final int numThreads;
    private final CRF crf;
    private CRFOptimizableByBatchLabelLikelihood optimizable;
    private ThreadedOptimizable threadedOptimizable;
    private Optimizer optimizer;
    private boolean useSparseWeights = true;
    private boolean useNoWeights = false;
    private transient boolean useSomeUnsupportedTrick = true;
    private boolean converged = false;
    private int iterationCount = 0;
    private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    // The CRF's weightsStructureChangeStamp at the time its weights were last sized here;
    // -1 guarantees the weights are sized on first use.
    private int cachedWeightsStructureStamp = -1;

    /**
     * Creates a trainer for {@code crf}.
     *
     * @param crf the model whose parameters are trained
     * @param numThreads thread count handed to {@link CRFOptimizableByBatchLabelLikelihood}
     */
    public CRFTrainerByThreadedLabelLikelihood(CRF crf, int numThreads) {
        this.crf = crf;
        this.numThreads = numThreads;
    }

    /** Returns the CRF being trained. */
    @Override
    public Transducer getTransducer() {
        return crf;
    }

    /** Returns the CRF being trained. */
    public CRF getCRF() {
        return crf;
    }

    /** Returns the current optimizer, or {@code null} if none has been built yet. */
    @Override
    public Optimizer getOptimizer() {
        return optimizer;
    }

    /** Returns whether the most recent {@code train} call ended by converging (or failing). */
    public boolean isConverged() {
        return converged;
    }

    /** Same as {@link #isConverged()}. */
    @Override
    public boolean isFinishedTraining() {
        return converged;
    }

    /** Returns the number of optimizer iterations completed so far by this trainer. */
    @Override
    public int getIteration() {
        return iterationCount;
    }

    /** Sets the Gaussian prior variance used for objectives built after this call. */
    public void setGaussianPriorVariance(double variance) {
        this.gaussianPriorVariance = variance;
    }

    public double getGaussianPriorVariance() {
        return gaussianPriorVariance;
    }

    /**
     * Chooses how CRF weights are sized when the weight structure changes: from the training data
     * ({@code true}) or densely ({@code false}).
     */
    public void setUseSparseWeights(boolean useSparseWeights) {
        this.useSparseWeights = useSparseWeights;
    }

    public boolean getUseSparseWeights() {
        return useSparseWeights;
    }

    /** Flag forwarded to {@code CRF.setWeightsDimensionAsIn} when sparse weights are used. */
    public void setUseSomeUnsupportedTrick(boolean useSomeUnsupportedTrick) {
        this.useSomeUnsupportedTrick = useSomeUnsupportedTrick;
    }

    /** When {@code true}, the trainer never re-sizes the CRF's weights itself. */
    public void setAddNoFactors(boolean addNoFactors) {
        this.useNoWeights = addNoFactors;
    }

    /**
     * Shuts down the threaded objective, if one exists, and drops the cached objective and
     * optimizer so that a later {@code train} call rebuilds them instead of reusing a dead one.
     * Safe to call before training and more than once.
     */
    public void shutdown() {
        shutdownThreadedOptimizable();
        optimizable = null;
        optimizer = null;
    }

    /** Shuts down and forgets the current {@link ThreadedOptimizable}, if any. */
    private void shutdownThreadedOptimizable() {
        if (threadedOptimizable != null) {
            threadedOptimizable.shutdown();
            threadedOptimizable = null;
        }
    }

    /**
     * Returns the likelihood objective for {@code instanceList}, (re)building it as needed.
     *
     * <p>If the CRF's weight structure changed since the last call, the CRF's weights are first
     * re-sized (unless {@link #setAddNoFactors(boolean)} was set). The objective and its threaded
     * wrapper are rebuilt whenever the weight structure changed or the training set differs;
     * rebuilding also discards the current optimizer.
     *
     * @param instanceList the training data
     * @return the (possibly cached) objective for {@code instanceList}
     */
    public CRFOptimizableByBatchLabelLikelihood getOptimizableCRF(InstanceList instanceList) {
        if (cachedWeightsStructureStamp != crf.weightsStructureChangeStamp) {
            if (!useNoWeights) {
                if (useSparseWeights) {
                    crf.setWeightsDimensionAsIn(instanceList, useSomeUnsupportedTrick);
                } else {
                    crf.setWeightsDimensionDensely();
                }
            }
            optimizable = null;
            // Read the stamp after re-sizing so the resize itself is not seen as a new change.
            cachedWeightsStructureStamp = crf.weightsStructureChangeStamp;
        }
        if (optimizable == null || optimizable.trainingSet != instanceList) {
            // Release the wrapper being replaced; otherwise every new training set (e.g. each
            // round of train(..., double[])) would leave its workers running.
            shutdownThreadedOptimizable();
            optimizable = new CRFOptimizableByBatchLabelLikelihood(crf, instanceList, numThreads);
            optimizable.setGaussianPriorVariance(gaussianPriorVariance);
            threadedOptimizable = new ThreadedOptimizable(optimizable, instanceList,
                    crf.getParameters().getNumFactors(), new CRFCacheStaleIndicator(crf));
            optimizer = null;
        }
        return optimizable;
    }

    /**
     * Returns an L-BFGS optimizer over the threaded objective for {@code instanceList}, reusing
     * the existing one (and its history) while the objective is unchanged.
     *
     * @param instanceList the training data
     * @return the optimizer to use for training
     */
    public Optimizer getOptimizer(InstanceList instanceList) {
        getOptimizableCRF(instanceList);
        // The optimizer is built over the threaded wrapper, so that is the object to compare
        // with. Comparing against the unwrapped objective never matched, which discarded the
        // L-BFGS state on every call.
        if (optimizer == null || optimizer.getOptimizable() != threadedOptimizable) {
            optimizer = new LimitedMemoryBFGS(threadedOptimizable);
        }
        return optimizer;
    }

    /** Trains until convergence; equivalent to {@code train(instanceList, Integer.MAX_VALUE)}. */
    public boolean trainIncremental(InstanceList instanceList) {
        return train(instanceList, Integer.MAX_VALUE);
    }

    /**
     * Runs up to {@code numIterations} single optimizer iterations, running evaluators after
     * each one and stopping early on convergence.
     *
     * <p>Any exception thrown by an iteration is logged and treated as convergence.
     *
     * @param instanceList the training data; must be non-empty
     * @param numIterations maximum number of iterations; {@code <= 0} does nothing
     * @return whether training converged (or was stopped by an exception)
     */
    @Override
    public boolean train(InstanceList instanceList, int numIterations) {
        if (numIterations <= 0) {
            return false;
        }
        assert instanceList.size() > 0 : "training set is empty";
        getOptimizer(instanceList); // also (re)builds the objective if necessary
        logger.info("CRF about to train with " + numIterations + " iterations");
        boolean hasConverged = false;
        for (int i = 0; i < numIterations; i++) {
            try {
                hasConverged = optimizer.optimize(1);
                iterationCount++;
                logger.info("CRF finished one iteration of maximizer, i=" + i);
                runEvaluators();
            } catch (Exception e) {
                logger.log(Level.WARNING, "Catching exception; saying converged.", e);
                hasConverged = true;
            }
            if (hasConverged) {
                logger.info("CRF training has converged, i=" + i);
                break;
            }
        }
        // Record the outcome so isConverged()/isFinishedTraining() reflect it.
        converged = hasConverged;
        return hasConverged;
    }

    /**
     * Trains in rounds, one per entry of {@code trainingProportions}. Each round uses that
     * fraction of {@code instanceList} for {@code numIterationsPerProportion} iterations.
     *
     * @param instanceList the full training data
     * @param numIterationsPerProportion iteration budget for each round
     * @param trainingProportions non-empty fractions in (0, 1]; 1.0 means the whole set
     * @return the convergence result of the last round
     */
    public boolean train(InstanceList instanceList, int numIterationsPerProportion, double[] trainingProportions) {
        assert trainingProportions.length > 0 : "no training proportions given";
        boolean hasConverged = false;
        for (double proportion : trainingProportions) {
            assert proportion <= 1.0d : "proportion > 1: " + proportion;
            logger.info("Training on " + (proportion * 100.0d) + "% of the data this round.");
            InstanceList subset;
            if (proportion == 1.0d) {
                subset = instanceList;
            } else {
                // Fixed seed: the split is reproducible across rounds and runs.
                subset = instanceList.split(new Random(1L),
                        new double[] {proportion, 1.0d - proportion})[0];
            }
            hasConverged = train(subset, numIterationsPerProportion);
        }
        return hasConverged;
    }
}
