package com.xiam.consia.ml.attributeselection;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.xiam.consia.ml.data.DataRecord;
import com.xiam.consia.ml.data.DataRecords;
import com.xiam.consia.ml.data.attribute.Attribute;
import com.xiam.consia.ml.tree.SplitInfoBuilder;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: classes.dex */
/**
 * Attribute selection strategy that scores each candidate attribute by information gain:
 * the entropy of the class labels minus the entropy conditioned on the attribute.
 *
 * <p>Discrete attributes are scored as a whole; continuous attributes are scored at every
 * boundary between distinct sorted values and the best threshold is kept.
 *
 * <p>Note: this file is decompiler (JADX) output. Several members were widened from
 * package-private/private to public by the decompiler; visibility is left as found so that
 * existing callers keep compiling.
 */
public class InformationGain extends AttributeSelection {
    /** Shared comparator ordering records by their cached numeric sort value. */
    private static final DATA_RECORD_SORT_ON_NUMERIC_ATTRIBUTE sortRecordByDoubleAttribute = new DATA_RECORD_SORT_ON_NUMERIC_ATTRIBUTE();

    /**
     * Orders {@link DataRecord}s ascending by {@code getNumericValueToSortOn()}.
     * The value must have been set on every record before sorting
     * (see {@link #sortRecordsByAttribute(List, int)}).
     */
    public static final class DATA_RECORD_SORT_ON_NUMERIC_ATTRIBUTE extends Ordering<DataRecord> {
        private DATA_RECORD_SORT_ON_NUMERIC_ATTRIBUTE() {
        }

        @Override // com.google.common.collect.Ordering, java.util.Comparator
        public int compare(DataRecord dataRecord, DataRecord dataRecord2) {
            return Double.compare(dataRecord.getNumericValueToSortOn(), dataRecord2.getNumericValueToSortOn());
        }
    }

    /**
     * Creates the strategy; both arguments are passed straight to the superclass.
     *
     * @param z forwarded to {@link AttributeSelection} (presumably the "do randomisation" flag — confirm against the superclass)
     * @param i forwarded to {@link AttributeSelection} (presumably the number of random attributes per split — confirm against the superclass)
     */
    public InformationGain(boolean z, int i) {
        super(z, i);
    }

    /**
     * Boxes a {@code long} as a {@code Short} using a narrowing cast.
     * Values outside the {@code short} range wrap around silently.
     *
     * @param j value to convert
     * @return {@code (short) j}, boxed
     */
    public static Short asShort(long j) {
        return Short.valueOf((short) j);
    }

    /**
     * Computes {@code -sum(p * ln(p))} where {@code p = count / d} for each count.
     * Terms with {@code p <= 0} are skipped.
     *
     * @param collection per-class counts
     * @param d          total used as the denominator
     * @return the entropy in nats, or {@code -0.0} when {@code d <= 0} (same convention as
     *         {@link #calcY(Map, double, Set)})
     */
    private static double calcClassEntropy(Collection<Short> collection, double d) {
        // Without this guard a zero total turns positive counts into +Infinity probabilities
        // and the result into -Infinity.
        if (d <= 0.0d) {
            return -0.0d;
        }
        double sum = 0.0d;
        for (Short count : collection) {
            double probability = count.doubleValue() / d;
            if (probability > 0.0d) {
                sum += probability * Math.log(probability);
            }
        }
        return -sum;
    }

    /**
     * Computes {@code -sum(p * ln(p))} over the labels in {@code set} that have a count in
     * {@code map}, where {@code p = count / d}. Labels without a count, and terms with
     * {@code p <= 0}, contribute nothing.
     *
     * @param map count per label
     * @param d   total used as the denominator
     * @param set labels to include
     * @return the entropy in nats, or {@code -0.0} when {@code d <= 0}
     */
    public static double calcY(Map<String, Short> map, double d, Set<String> set) {
        if (d <= 0.0d) {
            return -0.0d;
        }
        double sum = 0.0d;
        for (String label : set) {
            // Single lookup; a missing (or null) count is treated as "label absent".
            Short count = map.get(label);
            if (count == null) {
                continue;
            }
            double probability = count.doubleValue() / d;
            if (probability > 0.0d) {
                sum += probability * Math.log(probability);
            }
        }
        return -sum;
    }

    /**
     * Information gain: class entropy minus the entropy conditioned on the attribute, as
     * reported by {@code attributeCounter}.
     *
     * @param map              count per class label
     * @param attributeCounter supplies the conditional entropy for the attribute being scored
     * @param i                total number of records
     * @return the information gain
     */
    public static double calculateInformationGain(Map<String, Short> map, AttributeCounter attributeCounter, int i) {
        return calcClassEntropy(map.values(), i) - attributeCounter.calcEntropyConditionedOnAttribute(map.keySet(), i);
    }

    /**
     * Picks the attribute index to evaluate next. Without randomisation this is simply
     * {@code i}; with randomisation a not-yet-used random index is drawn and recorded in
     * {@code set}.
     *
     * @param dataRecords data set whose attribute count bounds the random draw
     * @param i           sequential index used when randomisation is off
     * @param set         indices already drawn; updated by this call
     * @return the attribute index to evaluate
     * @throws IllegalStateException if randomisation is on and every index has already been drawn
     */
    private int getAttributeIndex(DataRecords dataRecords, int i, Set<Integer> set) {
        if (!isDoRandomisation()) {
            return i;
        }
        int numAttributes = dataRecords.getNumAttributes();
        // Rejection sampling below can only terminate while an unused index remains.
        if (set.size() >= numAttributes) {
            throw new IllegalStateException("No unused attribute index left: " + set.size() + " of " + numAttributes + " already drawn");
        }
        int candidate = random.nextInt(numAttributes);
        while (set.contains(Integer.valueOf(candidate))) {
            candidate = random.nextInt(numAttributes);
        }
        set.add(Integer.valueOf(candidate));
        return candidate;
    }

    /**
     * Number of attributes to score for one split: all of them, or the configured random
     * subset size capped at the number of attributes that actually exist.
     */
    private int getNumAttributesToEvaluate(DataRecords dataRecords) {
        int numAttributes = dataRecords.getNumAttributes();
        if (!isDoRandomisation()) {
            return numAttributes;
        }
        // Asking for more distinct random indices than exist would never terminate.
        return Math.min(getNumRandomAttsForSplitCriteria(), numAttributes);
    }

    /**
     * Finds the best threshold for continuous attribute {@code i}.
     *
     * <p>The records are sorted by the attribute (in place) and walked in ascending order.
     * Each time the value increases, the gain for the counter's current state is computed;
     * the highest gain is stored in {@code dArr[i]} and the midpoint between the previous and
     * current value in {@code dArr2[i]}. When no candidate beats the sentinel, {@code dArr[i]}
     * stays {@code Double.MIN_VALUE} and {@code dArr2[i]} stays {@code -1.0}.
     *
     * @param list  records to evaluate; reordered by this call
     * @param i     attribute index
     * @param map   count per class label
     * @param dArr  receives the best gain at index {@code i}
     * @param dArr2 receives the best split value at index {@code i}
     */
    private void setGainAndSplitInfoForContinuousAttribute(List<DataRecord> list, int i, Map<String, Short> map, double[] dArr, double[] dArr2) {
        // Double.MIN_VALUE is the smallest POSITIVE double, so only strictly positive gains win.
        dArr[i] = Double.MIN_VALUE;
        dArr2[i] = -1.0d;
        if (list.isEmpty()) {
            // Nothing to split on; leave the "no split" sentinels in place.
            return;
        }
        List<DataRecord> sorted = sortRecordsByAttribute(list, i);
        int total = sorted.size();
        ContinuousAttributeCounter counter = ContinuousAttributeCounter.create(total, map);
        double previousValue = sorted.get(0).getDoubleAttributeValue(i).doubleValue();
        int position = 1;
        for (DataRecord record : sorted) {
            double currentValue = record.getDoubleAttributeValue(i).doubleValue();
            if (currentValue > previousValue) {
                // Boundary between two distinct values: score the split before this record is counted.
                double gain = calculateGain(map, counter, total);
                if (gain > dArr[i]) {
                    dArr[i] = gain;
                    dArr2[i] = (previousValue + currentValue) / 2.0d;
                }
            }
            counter.updateValueCounts(total, position);
            counter.updateAttributeCountsPerClassMaps(record.getClassLabel());
            position++;
            previousValue = currentValue;
        }
    }

    /**
     * Scores attribute {@code i}, dispatching on whether it is discrete.
     * Discrete attributes have no threshold, so {@code dArr2[i]} is set to
     * {@code Double.MIN_VALUE} for them.
     *
     * @param list  records to evaluate
     * @param z     {@code true} if the attribute is discrete
     * @param i     attribute index
     * @param map   count per class label
     * @param dArr  receives the gain at index {@code i}
     * @param dArr2 receives the split value at index {@code i}
     */
    private void setGainForAttribute(List<DataRecord> list, boolean z, int i, Map<String, Short> map, double[] dArr, double[] dArr2) {
        if (z) {
            setGainForDiscreteAttribute(list, i, map, dArr);
            dArr2[i] = Double.MIN_VALUE;
            return;
        }
        setGainAndSplitInfoForContinuousAttribute(list, i, map, dArr, dArr2);
    }

    /** Stores the gain of discrete attribute {@code i} over {@code collection} in {@code dArr[i]}. */
    private void setGainForDiscreteAttribute(Collection<DataRecord> collection, int i, Map<String, Short> map, double[] dArr) {
        dArr[i] = calculateGain(map, DiscreteAttributeCounter.createDiscrete(collection, i), collection.size());
    }

    /**
     * Sorts {@code list} in place, ascending by the numeric value of attribute {@code i}.
     * Each record's cached sort value is overwritten first.
     *
     * @return the same list instance, now sorted
     */
    private static List<DataRecord> sortRecordsByAttribute(List<DataRecord> list, int i) {
        for (DataRecord record : list) {
            record.setNumericValueToSortOn(record.getDoubleAttributeValue(i).doubleValue());
        }
        Collections.sort(list, sortRecordByDoubleAttribute);
        return list;
    }

    /**
     * Gain measure used to rank attributes. This implementation returns plain information
     * gain; the method is protected so subclasses can substitute another measure.
     *
     * @param map              count per class label
     * @param attributeCounter counter for the attribute being scored
     * @param i                total number of records
     * @return the gain for the attribute
     */
    protected double calculateGain(Map<String, Short> map, AttributeCounter attributeCounter, int i) {
        return calculateInformationGain(map, attributeCounter, i);
    }

    /**
     * Finds the best attribute (and threshold, for continuous attributes) to split on.
     *
     * @param dataRecords data set to split
     * @param map         count per class label
     * @return a builder describing the chosen split
     */
    @Override // com.xiam.consia.ml.attributeselection.AttributeSelection
    public SplitInfoBuilder findOptimumSplit(DataRecords dataRecords, Map<String, Short> map) {
        return findOptimumSplit(dataRecords, map, new double[dataRecords.getNumAttributes()]);
    }

    /**
     * Same as {@link #findOptimumSplit(DataRecords, Map)} but writes the gain of every
     * evaluated attribute into the caller-supplied array so tests can inspect it.
     *
     * <p>The builder is created from: the index of the winning attribute ({@code -1} if none
     * had a positive gain), its split value, whether it is discrete, and a flag that is
     * {@code true} only when an attribute was actually selected.
     *
     * @param dataRecords data set to split
     * @param map         count per class label
     * @param dArr        receives the gain per attribute index; must hold at least
     *                    {@code dataRecords.getNumAttributes()} elements
     * @return a builder describing the chosen split
     */
    @VisibleForTesting
    SplitInfoBuilder findOptimumSplit(DataRecords dataRecords, Map<String, Short> map, double[] dArr) {
        int numAttributesToEvaluate = getNumAttributesToEvaluate(dataRecords);
        double[] splitValues = new double[dataRecords.getNumAttributes()];
        HashSet<Integer> usedIndices = Sets.newHashSet();

        // Double.MIN_VALUE is the smallest POSITIVE double: a gain must be strictly greater
        // than it (hence strictly positive) to be selected. NaN gains never compare greater
        // and are therefore ignored.
        double bestGain = Double.MIN_VALUE;
        double bestSplitValue = Double.MIN_VALUE;
        int bestIndex = -1;
        boolean bestIsDiscrete = false;

        for (int n = 0; n < numAttributesToEvaluate; n++) {
            int attributeIndex = getAttributeIndex(dataRecords, n, usedIndices);
            Attribute attribute = dataRecords.getAttributes().get(attributeIndex);
            setGainForAttribute(dataRecords.getDataRecords(), attribute.isDiscrete, attributeIndex, map, dArr, splitValues);
            double gain = dArr[attributeIndex];
            if (gain > bestGain) {
                bestGain = gain;
                bestSplitValue = splitValues[attributeIndex];
                bestIsDiscrete = attribute.isDiscrete;
                bestIndex = attributeIndex;
            }
        }
        // The flag must reflect whether a split was found. Testing "bestGain > 0" is not
        // enough because the initial sentinel itself is positive.
        return SplitInfoBuilder.build(bestIndex, bestSplitValue, bestIsDiscrete, bestIndex >= 0);
    }
}
