package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;
import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;

/* loaded from: input_file:weka/classifiers/mi/MIBoost.class */
/**
 * Boosting meta-classifier for multi-instance data.
 *
 * <p>Training ({@link #buildClassifier(Instances)}) flattens the bags into single instances with
 * a {@link MultiInstanceToPropositional} filter, optionally discretizes them, and then fits up to
 * {@code m_MaxIterations} copies of the base classifier. After each fit the per-bag error rate
 * (fraction of the bag's instances that the model misclassifies) is measured, a coefficient
 * {@code c} is found by minimizing {@code sum_i w_i * exp(c * (2 * err_i - 1))}, and bag/instance
 * weights are updated from it.
 *
 * <p>Prediction ({@link #distributionForInstance(Instance)}) lets every fitted model vote on every
 * instance of the bag, accumulates the model coefficients divided by the bag size per voted class,
 * exponentiates and normalizes.
 */
public class MIBoost extends SingleClassifierEnhancer implements OptionHandler, MultiInstanceCapabilitiesHandler, TechnicalInformationHandler {
    /** Serialization version identifier. */
    static final long serialVersionUID = -3808427225599279539L;

    /** Copies of the base classifier, one per boosting iteration. */
    protected Classifier[] m_Models;

    /** Number of class values of the training data, recorded by {@code buildClassifier}. */
    protected int m_NumClasses;

    /** Not referenced anywhere in the code of this class. */
    protected int[] m_Classes;

    /** Not referenced anywhere in the code of this class. */
    protected Instances m_Attributes;

    /** Coefficient of each fitted model; index matches {@link #m_Models}. */
    protected double[] m_Beta;

    /**
     * Number of models actually used; reset to {@code m_MaxIterations} at the start of
     * {@code buildClassifier} and lowered when boosting stops early.
     */
    private int m_NumIterations = 100;

    /** Upper bound on the number of boosting iterations (option {@code -R}). */
    protected int m_MaxIterations = 10;

    /** Number of discretization bins (option {@code -B}); values {@code <= 0} disable discretization. */
    protected int m_DiscretizeBin = 0;

    /** Discretization filter; only created by {@code buildClassifier} when {@code m_DiscretizeBin > 0}. */
    protected Discretize m_Filter = null;

    /** Filter that turns bags into single-instance (propositional) data. */
    protected MultiInstanceToPropositional m_ConvertToSI = new MultiInstanceToPropositional();

    /**
     * One-dimensional optimization problem used to find a model's coefficient: minimizes
     * {@code sum_i weights[i] * exp(x[0] * (2 * errs[i] - 1))} over {@code x[0]}.
     *
     * <p>Declared {@code static} because it uses no state of the enclosing {@code MIBoost}.
     */
    private static class OptEng extends Optimization {
        /** Per-bag weights. */
        private double[] weights;

        /** Per-bag error rates. */
        private double[] errs;

        /**
         * Sets the per-bag weights used by the objective.
         *
         * @param bagWeights one weight per bag; stored by reference
         */
        public void setWeights(double[] bagWeights) {
            this.weights = bagWeights;
        }

        /**
         * Sets the per-bag error rates used by the objective.
         *
         * @param bagErrors one error rate per bag; stored by reference
         */
        public void setErrs(double[] bagErrors) {
            this.errs = bagErrors;
        }

        /**
         * Evaluates {@code sum_i weights[i] * exp(x[0] * (2 * errs[i] - 1))}.
         *
         * @param x the variable vector; only {@code x[0]} is read
         * @return the objective value
         * @throws Exception if the running sum becomes NaN
         */
        @Override
        protected double objectiveFunction(double[] x) throws Exception {
            double total = 0.0d;
            for (int i = 0; i < this.weights.length; i++) {
                total += this.weights[i] * Math.exp(x[0] * ((2.0d * this.errs[i]) - 1.0d));
                if (Double.isNaN(total)) {
                    throw new Exception("Objective function value is NaN!");
                }
            }
            return total;
        }

        /**
         * Evaluates the derivative of the objective with respect to {@code x[0]}.
         *
         * @param x the variable vector; only {@code x[0]} is read
         * @return a one-element array holding the derivative
         * @throws Exception if the running sum becomes NaN
         */
        @Override
        protected double[] evaluateGradient(double[] x) throws Exception {
            double[] gradient = new double[1];
            for (int i = 0; i < this.weights.length; i++) {
                gradient[0] = gradient[0] + (this.weights[i] * ((2.0d * this.errs[i]) - 1.0d) * Math.exp(x[0] * ((2.0d * this.errs[i]) - 1.0d)));
                if (Double.isNaN(gradient[0])) {
                    throw new Exception("Gradient is NaN!");
                }
            }
            return gradient;
        }
    }

    /**
     * Returns a textual description of this classifier, followed by the literature reference.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "MI AdaBoost method, considers the geometric mean of posterior of instances inside a bag (arithmatic mean of log-posterior) and the expectation for a bag is taken inside the loss function.\n\nFor more information about Adaboost, see:\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Builds the bibliographic record for the AdaBoost paper cited by {@link #globalInfo()}.
     *
     * @return a freshly populated {@code INPROCEEDINGS} record
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation reference = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        reference.setValue(TechnicalInformation.Field.AUTHOR, "Yoav Freund and Robert E. Schapire");
        reference.setValue(TechnicalInformation.Field.TITLE, "Experiments with a new boosting algorithm");
        reference.setValue(TechnicalInformation.Field.BOOKTITLE, "Thirteenth International Conference on Machine Learning");
        reference.setValue(TechnicalInformation.Field.YEAR, "1996");
        reference.setValue(TechnicalInformation.Field.PAGES, "148-156");
        reference.setValue(TechnicalInformation.Field.PUBLISHER, "Morgan Kaufmann");
        reference.setValue(TechnicalInformation.Field.ADDRESS, "San Francisco");
        return reference;
    }

    /**
     * Lists the options of this classifier ({@code -D}, {@code -B}, {@code -R}, {@code -W}),
     * followed by every option reported by the current base classifier.
     *
     * @return an enumeration over the collected option descriptors
     */
    @Override
    public Enumeration listOptions() {
        // Element type is Object because the base classifier's enumeration is consumed untyped.
        Vector<Object> options = new Vector<Object>();
        options.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        options.addElement(new Option("\tThe number of bins in discretization\n\t(default 0, no discretization)", "B", 1, "-B <num>"));
        options.addElement(new Option("\tMaximum number of boost iterations.\n\t(default 10)", "R", 1, "-R <num>"));
        options.addElement(new Option("\tFull name of classifier to boost.\n\teg: weka.classifiers.bayes.NaiveBayes", "W", 1, "-W <class name>"));
        Enumeration<?> baseOptions = this.m_Classifier.listOptions();
        while (baseOptions.hasMoreElements()) {
            options.addElement(baseOptions.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses {@code -D}, {@code -B <num>} (default 0) and {@code -R <num>} (default 10), then
     * hands the array to the superclass for the remaining options.
     *
     * @param strArr the option array
     * @throws Exception if an option cannot be parsed or the superclass rejects the options
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        setDebug(Utils.getFlag('D', strArr));

        String binsText = Utils.getOption('B', strArr);
        setDiscretizeBin(binsText.length() != 0 ? Integer.parseInt(binsText) : 0);

        String iterationsText = Utils.getOption('R', strArr);
        setMaxIterations(iterationsText.length() != 0 ? Integer.parseInt(iterationsText) : 10);

        super.setOptions(strArr);
    }

    /**
     * Returns the current settings as an option array: {@code -R}, {@code -B}, then whatever the
     * superclass reports.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-R");
        options.add("" + getMaxIterations());
        options.add("-B");
        options.add("" + getDiscretizeBin());
        for (String inherited : super.getOptions()) {
            options.add(inherited);
        }
        return options.toArray(new String[options.size()]);
    }

    /**
     * Returns the tip text for the {@code maxIterations} property.
     *
     * @return the tip text
     */
    public String maxIterationsTipText() {
        return "The maximum number of boost iterations.";
    }

    /**
     * Sets the maximum number of boosting iterations. The value is stored without validation.
     *
     * @param i the maximum number of iterations
     */
    public void setMaxIterations(int i) {
        this.m_MaxIterations = i;
    }

    /**
     * Returns the maximum number of boosting iterations.
     *
     * @return the maximum number of iterations
     */
    public int getMaxIterations() {
        return this.m_MaxIterations;
    }

    /**
     * Returns the tip text for the {@code discretizeBin} property.
     *
     * @return the tip text
     */
    public String discretizeBinTipText() {
        return "The number of bins in discretization.";
    }

    /**
     * Sets the number of discretization bins. The value is stored without validation; values
     * {@code <= 0} disable discretization.
     *
     * @param i the number of bins
     */
    public void setDiscretizeBin(int i) {
        this.m_DiscretizeBin = i;
    }

    /**
     * Returns the number of discretization bins.
     *
     * @return the number of bins
     */
    public int getDiscretizeBin() {
        return this.m_DiscretizeBin;
    }

    /**
     * Returns the capabilities of this classifier: the superclass capabilities with nominal and
     * relational attributes and missing values enabled, all class types disabled except binary
     * class (only if the superclass capabilities handle it), plus missing class values and the
     * multi-instance-only flag.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.disableAllClassDependencies();
        // The superclass is queried again because the class capabilities of 'result' were just cleared.
        if (super.getCapabilities().handles(Capabilities.Capability.BINARY_CLASS)) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    /**
     * Returns the capabilities required of the data inside a bag: the superclass capabilities
     * with all class types disabled and {@code NO_CLASS} enabled.
     *
     * @return the capabilities for the relational (bag) data
     */
    @Override
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Trains the boosted ensemble on multi-instance data.
     *
     * <p>Works on a copy of {@code instances} with missing-class bags removed. Every bag starts
     * with weight (total instances / number of bags). Boosting stops early when a model's bag
     * errors are all on one side of 0.5, or when the fitted coefficient is infinite or not above
     * {@code KStarConstants.FLOOR}.
     *
     * @param instances the multi-instance training data
     * @throws Exception if the data fails the capability test, no base classifier is set, the
     *     base classifier is not a {@link WeightedInstancesHandler}, or filtering, model fitting,
     *     optimization or reweighting fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);

        Instances bags = new Instances(instances);
        bags.deleteWithMissingClass();
        this.m_NumClasses = bags.numClasses();
        this.m_NumIterations = this.m_MaxIterations;

        if (this.m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        if (!(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new Exception("Base classifier cannot handle weighted instances!");
        }
        this.m_Models = Classifier.makeCopies(this.m_Classifier, getMaxIterations());
        if (this.m_Debug) {
            System.err.println("Base classifier: " + this.m_Classifier.getClass().getName());
        }
        this.m_Beta = new double[this.m_NumIterations];

        // Attribute 1 of each bag is read as its relational content.
        int numBags = bags.numInstances();
        double totalInstances = 0.0d;
        for (int b = 0; b < numBags; b++) {
            totalInstances += bags.instance(b).relationalValue(1).numInstances();
        }
        for (int b = 0; b < numBags; b++) {
            bags.instance(b).setWeight(totalInstances / numBags);
        }

        this.m_ConvertToSI.setInputFormat(bags);
        Instances flat = Filter.useFilter(bags, this.m_ConvertToSI);
        // Drops the first attribute of the flattened data (presumably the bag identifier — confirm).
        flat.deleteAttributeAt(0);
        if (this.m_DiscretizeBin > 0) {
            this.m_Filter = new Discretize();
            this.m_Filter.setInputFormat(new Instances(flat, 0));
            this.m_Filter.setBins(this.m_DiscretizeBin);
            flat = Filter.useFilter(flat, this.m_Filter);
        }

        for (int iteration = 0; iteration < this.m_MaxIterations; iteration++) {
            if (this.m_Debug) {
                System.err.println("\nIteration " + iteration);
            }
            this.m_Models[iteration].buildClassifier(flat);

            double[] bagErrors = new double[numBags];
            double[] bagWeights = new double[numBags];
            if (measureBagErrors(this.m_Models[iteration], bags, flat, bagErrors, bagWeights)) {
                stopBoosting(iteration, "No errors");
                return;
            }

            this.m_Beta[iteration] = searchBeta(bagWeights, bagErrors);
            if (this.m_Debug) {
                System.err.println("c = " + this.m_Beta[iteration]);
            }
            // NOTE(review): KStarConstants.FLOOR is presumably 0.0 (decompiler constant substitution) — confirm.
            if (Double.isInfinite(this.m_Beta[iteration]) || Utils.smOrEq(this.m_Beta[iteration], KStarConstants.FLOOR)) {
                stopBoosting(iteration, "Errors out of range!");
                return;
            }

            reweightBags(bags, flat, bagWeights, bagErrors, this.m_Beta[iteration], totalInstances);
        }
    }

    /**
     * Fills in each bag's error rate and current weight for one fitted model.
     *
     * <p>The flattened instances are consumed in bag order, so {@code flat} must list the
     * instances of bag 0 first, then bag 1, and so on. NOTE(review): a bag with no instances
     * yields {@code 0.0 / 0}, i.e. a NaN error rate, which leaves both flags untouched.
     *
     * @param model the model fitted in this iteration
     * @param bags the training bags
     * @param flat the flattened training instances
     * @param bagErrors output: misclassified fraction per bag (must be zero-initialized)
     * @param bagWeights output: current weight per bag
     * @return {@code true} if no bag error exceeds 0.5 or no bag error is below 0.5
     * @throws Exception if the model fails to classify an instance
     */
    private boolean measureBagErrors(Classifier model, Instances bags, Instances flat, double[] bagErrors, double[] bagWeights) throws Exception {
        boolean noneAboveHalf = true;
        boolean noneBelowHalf = true;
        int flatIndex = 0;
        for (int b = 0; b < bagErrors.length; b++) {
            Instance bag = bags.instance(b);
            int bagSize = bag.relationalValue(1).numInstances();
            for (int k = 0; k < bagSize; k++) {
                Instance member = flat.instance(flatIndex);
                flatIndex++;
                if (((int) model.classifyInstance(member)) != ((int) bag.classValue())) {
                    bagErrors[b] = bagErrors[b] + 1.0d;
                }
            }
            bagWeights[b] = bag.weight();
            bagErrors[b] = bagErrors[b] / bagSize;
            if (bagErrors[b] > 0.5d) {
                noneAboveHalf = false;
            }
            if (bagErrors[b] < 0.5d) {
                noneBelowHalf = false;
            }
        }
        return noneAboveHalf || noneBelowHalf;
    }

    /**
     * Finds the coefficient that minimizes the {@link OptEng} objective for the given bag weights
     * and errors, restarting the optimizer from its last position for as long as it returns
     * {@code null}.
     *
     * @param bagWeights current weight per bag
     * @param bagErrors error rate per bag
     * @return the first component of the optimizer's solution
     * @throws Exception if the optimizer fails
     */
    private double searchBeta(double[] bagWeights, double[] bagErrors) throws Exception {
        double[] start = {KStarConstants.FLOOR};
        // Both bound rows are NaN (presumably "unbounded" for Optimization — confirm).
        double[][] bounds = new double[2][start.length];
        bounds[0][0] = Double.NaN;
        bounds[1][0] = Double.NaN;

        OptEng optimizer = new OptEng();
        optimizer.setWeights(bagWeights);
        optimizer.setErrs(bagErrors);
        if (this.m_Debug) {
            System.out.println("Start searching for c... ");
        }

        double[] solution = optimizer.findArgmin(start, bounds);
        while (solution == null) {
            double[] resumeFrom = optimizer.getVarbValues();
            if (this.m_Debug) {
                System.out.println("200 iterations finished, not enough!");
            }
            solution = optimizer.findArgmin(resumeFrom, bounds);
        }
        if (this.m_Debug) {
            System.out.println("Finished.");
        }
        return solution[0];
    }

    /**
     * Records early termination at the given iteration: the coefficient becomes 1 for the first
     * model and 0 for any later one, and the number of used models is cut to
     * {@code iteration + 1}.
     *
     * @param iteration index of the iteration at which boosting stops
     * @param debugMessage message written to {@code System.err} when debugging is on
     */
    private void stopBoosting(int iteration, String debugMessage) {
        this.m_Beta[iteration] = (iteration == 0) ? 1.0d : 0.0d;
        this.m_NumIterations = iteration + 1;
        if (this.m_Debug) {
            System.err.println(debugMessage);
        }
    }

    /**
     * Updates bag and instance weights after a model was accepted.
     *
     * <p>Each bag weight becomes {@code oldWeight * exp(beta * (2 * err - 1))}, then all bag
     * weights are rescaled so that they sum to {@code totalInstances}; every flattened instance
     * receives its bag's weight divided by the bag size.
     *
     * @param bags the training bags (weights are modified)
     * @param flat the flattened training instances in bag order (weights are modified)
     * @param bagWeights bag weights before the update
     * @param bagErrors error rate per bag
     * @param beta coefficient of the accepted model
     * @param totalInstances total number of instances over all bags
     * @throws Exception if a flattened instance ends up with a NaN weight
     */
    private void reweightBags(Instances bags, Instances flat, double[] bagWeights, double[] bagErrors, double beta, double totalInstances) throws Exception {
        int numBags = bagWeights.length;
        double weightSum = 0.0d;
        for (int b = 0; b < numBags; b++) {
            Instance bag = bags.instance(b);
            bag.setWeight(bagWeights[b] * Math.exp(beta * ((2.0d * bagErrors[b]) - 1.0d)));
            weightSum += bag.weight();
        }
        if (this.m_Debug) {
            System.err.println("Total weights = " + weightSum);
        }

        int flatIndex = 0;
        for (int b = 0; b < numBags; b++) {
            Instance bag = bags.instance(b);
            int bagSize = bag.relationalValue(1).numInstances();
            bag.setWeight((totalInstances * bag.weight()) / weightSum);
            for (int k = 0; k < bagSize; k++) {
                Instance member = flat.instance(flatIndex);
                member.setWeight(bag.weight() / bagSize);
                if (Double.isNaN(member.weight())) {
                    throw new Exception("instance " + k + " in bag " + b + " has weight NaN!");
                }
                flatIndex++;
            }
        }
    }

    /**
     * Computes the class distribution for one bag.
     *
     * <p>The bag is flattened (and discretized when {@code m_DiscretizeBin > 0}); every used model
     * votes on every instance, adding {@code beta / bagSize} to the voted class. The sums are
     * exponentiated and normalized.
     *
     * @param instance the bag to classify
     * @return the normalized class distribution, of length {@code m_NumClasses}
     * @throws Exception if filtering or classification fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] distribution = new double[this.m_NumClasses];

        Instances singleBag = new Instances(instance.dataset(), 0);
        singleBag.add(instance);
        Instances flat = Filter.useFilter(singleBag, this.m_ConvertToSI);
        flat.deleteAttributeAt(0);

        // The instance count is taken before discretization, as in training.
        int bagSize = flat.numInstances();
        if (this.m_DiscretizeBin > 0) {
            flat = Filter.useFilter(flat, this.m_Filter);
        }

        for (int k = 0; k < bagSize; k++) {
            Instance member = flat.instance(k);
            for (int m = 0; m < this.m_NumIterations; m++) {
                int votedClass = (int) this.m_Models[m].classifyInstance(member);
                distribution[votedClass] = distribution[votedClass] + (this.m_Beta[m] / bagSize);
            }
        }
        for (int c = 0; c < distribution.length; c++) {
            distribution[c] = Math.exp(distribution[c]);
        }
        Utils.normalize(distribution);
        return distribution;
    }

    /**
     * Returns a description of the built ensemble: the discretization setting, each used base
     * classifier with its coefficient, and the number of performed iterations.
     *
     * @return the description, or {@code "No model built yet!"} before training
     */
    public String toString() {
        if (this.m_Models == null) {
            return "No model built yet!";
        }
        StringBuilder text = new StringBuilder();
        text.append("MIBoost: number of bins in discretization = " + this.m_DiscretizeBin + "\n");
        if (this.m_NumIterations == 0) {
            text.append("No model built yet.\n");
        } else if (this.m_NumIterations == 1) {
            text.append("No boosting possible, one classifier used: Weight = " + Utils.roundDouble(this.m_Beta[0], 2) + "\n");
            text.append("Base classifiers:\n" + this.m_Models[0].toString());
        } else {
            text.append("Base classifiers and their weights: \n");
            for (int m = 0; m < this.m_NumIterations; m++) {
                text.append("\n\n" + m + ": Weight = " + Utils.roundDouble(this.m_Beta[m], 2) + "\nBase classifier:\n" + this.m_Models[m].toString());
            }
        }
        text.append("\n\nNumber of performed Iterations: " + this.m_NumIterations + "\n");
        return text.toString();
    }

    /**
     * Command-line entry point; delegates to {@code runClassifier} with a new {@code MIBoost}.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        runClassifier(new MIBoost(), strArr);
    }
}
