package org.jpmml.evaluator.support_vector_machine;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.f;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.Array;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.support_vector_machine.Coefficient;
import org.dmg.pmml.support_vector_machine.Coefficients;
import org.dmg.pmml.support_vector_machine.Kernel;
import org.dmg.pmml.support_vector_machine.SupportVector;
import org.dmg.pmml.support_vector_machine.SupportVectorMachine;
import org.dmg.pmml.support_vector_machine.SupportVectorMachineModel;
import org.dmg.pmml.support_vector_machine.VectorDictionary;
import org.dmg.pmml.support_vector_machine.VectorFields;
import org.dmg.pmml.support_vector_machine.VectorInstance;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.SparseArrayUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;
import org.jpmml.evaluator.VoteDistribution;
import org.jpmml.model.ReflectionUtil;

/* loaded from: classes6.dex */
/**
 * Evaluator for PMML {@code SupportVectorMachineModel} elements, supporting the regression and
 * classification mining functions.
 *
 * <p>Construction fails fast on features this evaluator does not handle: {@code maxWins="true"},
 * any representation other than {@code SupportVectors}, and models lacking a vector dictionary,
 * vector fields, or support vector machines.
 *
 * <p>Support vectors are parsed once per model object and shared through a static cache; each
 * instance additionally memoizes its cache lookup in a transient field.
 */
public class SupportVectorMachineModelEvaluator extends ModelEvaluator<SupportVectorMachineModel> {

    /**
     * Maps each model to its parsed vector dictionary ({@code VectorInstance} id to dense coordinates).
     * {@code f} is the loading-cache type returned by {@link CacheUtil#buildLoadingCache}; the short
     * name presumably comes from an obfuscated Guava build — verify before renaming.
     */
    private static final f<SupportVectorMachineModel, Map<String, double[]>> vectorCache = CacheUtil.buildLoadingCache(new CacheLoader<SupportVectorMachineModel, Map<String, double[]>>() {

        @Override
        public Map<String, double[]> load(SupportVectorMachineModel supportVectorMachineModel) {
            return ImmutableMap.copyOf(parseVectorDictionary(supportVectorMachineModel));
        }
    });

    /** Per-instance memo of this model's entry in {@link #vectorCache}; filled lazily. */
    private transient Map<String, double[]> vectorMap = null;

    /**
     * Creates an evaluator for the {@code SupportVectorMachineModel} selected from {@code pmml}.
     */
    public SupportVectorMachineModelEvaluator(PMML pmml) {
        this(pmml, (SupportVectorMachineModel) selectModel(pmml, SupportVectorMachineModel.class));
    }

    /**
     * Creates an evaluator for the given model and validates the features it depends on.
     *
     * @throws UnsupportedFeatureException if {@code maxWins} is set or the representation is not
     *     {@code SUPPORT_VECTORS}
     * @throws InvalidFeatureException if the vector dictionary, its vector fields, or the support
     *     vector machines are missing
     */
    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);

        if (supportVectorMachineModel.isMaxWins()) {
            throw new UnsupportedFeatureException(supportVectorMachineModel);
        }

        SupportVectorMachineModel.Representation representation = supportVectorMachineModel.getRepresentation();
        switch (representation) {
            case SUPPORT_VECTORS:
                break;
            default:
                throw new UnsupportedFeatureException(supportVectorMachineModel, representation);
        }

        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        if (vectorDictionary == null) {
            throw new InvalidFeatureException(supportVectorMachineModel);
        }
        if (vectorDictionary.getVectorFields() == null) {
            throw new InvalidFeatureException(vectorDictionary);
        }
        if (!supportVectorMachineModel.hasSupportVectorMachines()) {
            throw new InvalidFeatureException(supportVectorMachineModel);
        }
    }

    /**
     * Builds the dense input vector for the current record, one coordinate per child of
     * {@code VectorFields}, in document order.
     *
     * <p>A {@code FieldRef} child contributes the field's numeric value. A {@code CategoricalPredictor}
     * child contributes an indicator: 1 when {@code FieldValue.equals(HasValue)} matches the
     * predictor, otherwise 0; its coefficient must be exactly 1.
     *
     * @throws MissingValueException if a referenced field evaluates to no value
     * @throws InvalidFeatureException if a categorical predictor's coefficient is not 1
     * @throws UnsupportedFeatureException for any other kind of vector field
     */
    private double[] createInput(EvaluationContext evaluationContext) {
        VectorFields vectorFields = getModel().getVectorDictionary().getVectorFields();
        List<PMMLObject> content = vectorFields.getContent();

        double[] input = new double[content.size()];
        for (int i = 0; i < content.size(); i++) {
            PMMLObject vectorField = content.get(i);

            if (vectorField instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef) vectorField;

                FieldValue value = ExpressionUtil.evaluate(fieldRef, evaluationContext);
                if (value == null) {
                    throw new MissingValueException(fieldRef.getField(), vectorFields);
                }
                input[i] = value.asNumber().doubleValue();
            } else if (vectorField instanceof CategoricalPredictor) {
                CategoricalPredictor categoricalPredictor = (CategoricalPredictor) vectorField;

                // Only a unit coefficient can be represented as a plain 0/1 indicator
                if (categoricalPredictor.getCoefficient() != 1.0d) {
                    throw new InvalidFeatureException(categoricalPredictor);
                }

                FieldName name = categoricalPredictor.getName();
                FieldValue value = evaluationContext.evaluate(name);
                if (value == null) {
                    throw new MissingValueException(name, categoricalPredictor);
                }
                // The cast selects the FieldValue.equals(HasValue) overload, not Object.equals
                input[i] = value.equals((HasValue<?>) categoricalPredictor) ? 1.0d : 0.0d;
            } else {
                throw new UnsupportedFeatureException(vectorField);
            }
        }
        return input;
    }

    /**
     * Scores every support vector machine and combines the results according to the effective
     * classification method (see {@link #getClassificationMethod()}).
     *
     * <p>{@code ONE_AGAINST_ALL} records each machine's raw decision value as the distance of its
     * target category. {@code ONE_AGAINST_ONE} casts one vote per machine into a vote distribution.
     */
    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext modelEvaluationContext) {
        SupportVectorMachineModel model = getModel();
        List<SupportVectorMachine> supportVectorMachines = model.getSupportVectorMachines();
        String alternateBinaryTargetCategory = model.getAlternateBinaryTargetCategory();

        SupportVectorMachineModel.ClassificationMethod classificationMethod = getClassificationMethod();

        Classification classification;
        switch (classificationMethod) {
            case ONE_AGAINST_ALL:
                classification = new Classification(Classification.Type.DISTANCE);
                break;
            case ONE_AGAINST_ONE:
                classification = new VoteDistribution();
                break;
            default:
                throw new UnsupportedFeatureException(model, classificationMethod);
        }

        double[] input = createInput(modelEvaluationContext);

        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            String targetCategory = supportVectorMachine.getTargetCategory();
            String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();

            Double value = evaluateSupportVectorMachine(supportVectorMachine, input);

            switch (classificationMethod) {
                case ONE_AGAINST_ALL:
                    if (targetCategory == null || alternateTargetCategory != null) {
                        throw new InvalidFeatureException(supportVectorMachine);
                    }
                    classification.put(targetCategory, value);
                    break;
                case ONE_AGAINST_ONE:
                    String label;
                    if (alternateBinaryTargetCategory != null) {
                        label = selectBinaryLabel(supportVectorMachine, targetCategory, alternateTargetCategory, alternateBinaryTargetCategory, value);
                    } else {
                        label = selectThresholdLabel(model, supportVectorMachine, targetCategory, alternateTargetCategory, value);
                    }
                    addVote(classification, label);
                    break;
                default:
                    // Unreachable: the method was validated when the classification was created
                    throw new UnsupportedFeatureException(model, classificationMethod);
            }
        }
        return TargetUtil.evaluateClassification(classification, modelEvaluationContext);
    }

    /**
     * Picks the voted label for a binary model declaring {@code alternateBinaryTargetCategory}: a
     * decision value rounding to 1 votes for {@code targetCategory}, one rounding to 0 votes for the
     * model's alternate binary category.
     *
     * @throws InvalidFeatureException if the machine lacks a target category or declares its own
     *     alternate target category
     * @throws EvaluationException if the value rounds to anything other than 0 or 1
     */
    private static String selectBinaryLabel(SupportVectorMachine supportVectorMachine, String targetCategory, String alternateTargetCategory, String alternateBinaryTargetCategory, Double value) {
        if (targetCategory == null || alternateTargetCategory != null) {
            throw new InvalidFeatureException(supportVectorMachine);
        }

        long rounded = Math.round(value.doubleValue());
        if (rounded == 1) {
            return targetCategory;
        }
        if (rounded == 0) {
            return alternateBinaryTargetCategory;
        }
        throw new EvaluationException("Invalid numeric prediction " + value);
    }

    /**
     * Picks the voted label for a pairwise machine: values below the threshold (the machine's own,
     * else the model's) vote for {@code targetCategory}, all others for {@code alternateTargetCategory}.
     *
     * @throws InvalidFeatureException if either category is missing
     */
    private static String selectThresholdLabel(SupportVectorMachineModel model, SupportVectorMachine supportVectorMachine, String targetCategory, String alternateTargetCategory, Double value) {
        if (targetCategory == null || alternateTargetCategory == null) {
            throw new InvalidFeatureException(supportVectorMachine);
        }

        Double threshold = supportVectorMachine.getThreshold();
        if (threshold == null) {
            threshold = model.getThreshold();
        }

        // Double.compareTo (not '<') fixes the ordering of NaN and -0.0; NaN sorts above everything
        return value.compareTo(threshold) < 0 ? targetCategory : alternateTargetCategory;
    }

    /** Adds one vote for {@code label}, starting from zero when it has none yet. */
    private static void addVote(Classification classification, String label) {
        Double votes = classification.get(label);
        double current = (votes != null) ? votes.doubleValue() : 0.0d;
        classification.put(label, Double.valueOf(current + 1.0d));
    }

    /**
     * Evaluates a regression model, which must contain exactly one support vector machine.
     *
     * @throws InvalidFeatureException if the model does not have exactly one machine
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext modelEvaluationContext) {
        SupportVectorMachineModel model = getModel();
        List<SupportVectorMachine> supportVectorMachines = model.getSupportVectorMachines();
        if (supportVectorMachines.size() != 1) {
            throw new InvalidFeatureException(model);
        }

        double value = evaluateSupportVectorMachine(supportVectorMachines.get(0), createInput(modelEvaluationContext));
        return TargetUtil.evaluateRegression(Double.valueOf(value), modelEvaluationContext);
    }

    /**
     * Computes the decision value {@code sum(c_i * K(input, v_i)) + b} of one machine. Here
     * {@code c_i} are the {@code Coefficient} values paired positionally with the machine's support
     * vectors {@code v_i}, {@code K} is the model's kernel, and {@code b} is the coefficients'
     * {@code absoluteValue}.
     *
     * @throws InvalidFeatureException if a support vector id is not in the vector dictionary, or if
     *     the numbers of coefficients and support vectors differ
     */
    private double evaluateSupportVectorMachine(SupportVectorMachine supportVectorMachine, double[] input) {
        Kernel kernel = getModel().getKernel();
        Coefficients coefficients = supportVectorMachine.getCoefficients();
        Map<String, double[]> vectors = getVectorMap();

        Iterator<Coefficient> coefficientIt = coefficients.iterator();
        Iterator<SupportVector> supportVectorIt = supportVectorMachine.getSupportVectors().iterator();

        double sum = 0.0d;
        while (coefficientIt.hasNext() && supportVectorIt.hasNext()) {
            Coefficient coefficient = coefficientIt.next();
            SupportVector supportVector = supportVectorIt.next();

            double[] vector = vectors.get(supportVector.getVectorId());
            if (vector == null) {
                throw new InvalidFeatureException(supportVector);
            }
            sum += coefficient.getValue().doubleValue() * KernelUtil.evaluate(kernel, input, vector);
        }

        // Both sequences must be exhausted together
        if (coefficientIt.hasNext() || supportVectorIt.hasNext()) {
            throw new InvalidFeatureException(supportVectorMachine);
        }
        return sum + coefficients.getAbsoluteValue().doubleValue();
    }

    /**
     * Returns the explicitly declared classification method, or infers one when the attribute is
     * absent.
     *
     * <p>A model with {@code alternateBinaryTargetCategory} must have exactly one machine with a
     * target category and is {@code ONE_AGAINST_ONE}. Otherwise the first machine decides: it is
     * {@code ONE_AGAINST_ONE} when it declares an alternate target category, else
     * {@code ONE_AGAINST_ALL}.
     *
     * @throws InvalidFeatureException if the machines do not match any of these shapes
     */
    private SupportVectorMachineModel.ClassificationMethod getClassificationMethod() {
        SupportVectorMachineModel model = getModel();

        // The field is read reflectively rather than through the getter, presumably so that an
        // absent attribute is seen as null instead of a default — verify against pmml-model
        SupportVectorMachineModel.ClassificationMethod declared = (SupportVectorMachineModel.ClassificationMethod) ReflectionUtil.getFieldValue(ReflectionUtil.getField(SupportVectorMachineModel.class, "classificationMethod"), model);
        if (declared != null) {
            return declared;
        }

        List<SupportVectorMachine> supportVectorMachines = model.getSupportVectorMachines();

        if (model.getAlternateBinaryTargetCategory() != null) {
            if (supportVectorMachines.size() != 1) {
                throw new InvalidFeatureException(model);
            }
            SupportVectorMachine onlyMachine = supportVectorMachines.get(0);
            if (onlyMachine.getTargetCategory() == null) {
                throw new InvalidFeatureException(onlyMachine);
            }
            return SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE;
        }

        if (supportVectorMachines.isEmpty()) {
            throw new InvalidFeatureException(model);
        }
        SupportVectorMachine firstMachine = supportVectorMachines.get(0);
        if (firstMachine.getTargetCategory() == null) {
            throw new InvalidFeatureException(firstMachine);
        }
        return (firstMachine.getAlternateTargetCategory() != null)
                ? SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE
                : SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ALL;
    }

    /**
     * Returns this model's parsed vector dictionary, fetching it from {@link #vectorCache} on first
     * use.
     */
    @SuppressWarnings("unchecked")
    private Map<String, double[]> getVectorMap() {
        if (this.vectorMap == null) {
            this.vectorMap = (Map<String, double[]>) getValue(vectorCache);
        }
        return this.vectorMap;
    }

    /**
     * Parses every {@code VectorInstance} of the model's vector dictionary into a dense array, keyed
     * by instance id in document order.
     *
     * <p>Each instance must carry exactly one of {@code Array} or {@code REAL-SparseArray}, and its
     * length must equal the number of vector fields.
     *
     * @throws InvalidFeatureException if an instance lacks an id, has neither or both array forms, or
     *     has the wrong length
     */
    public static Map<String, double[]> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary vectorDictionary = supportVectorMachineModel.getVectorDictionary();
        int dimension = vectorDictionary.getVectorFields().getContent().size();

        Map<String, double[]> vectors = new LinkedHashMap<String, double[]>();
        for (VectorInstance vectorInstance : vectorDictionary.getVectorInstances()) {
            String id = vectorInstance.getId();
            if (id == null) {
                throw new InvalidFeatureException(vectorInstance);
            }

            Array array = vectorInstance.getArray();
            RealSparseArray sparseArray = vectorInstance.getRealSparseArray();

            List<? extends Number> values;
            if (array != null && sparseArray == null) {
                values = ArrayUtil.asNumberList(array);
            } else if (array == null && sparseArray != null) {
                values = SparseArrayUtil.asNumberList(sparseArray);
            } else {
                throw new InvalidFeatureException(vectorInstance);
            }

            if (values.size() != dimension) {
                throw new InvalidFeatureException(vectorInstance);
            }
            vectors.put(id, Doubles.a(values));
        }
        return vectors;
    }

    /**
     * Scores one record, dispatching on the model's mining function, then applies the model's
     * output fields.
     *
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException for mining functions other than regression and
     *     classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        SupportVectorMachineModel model = getModel();
        if (!model.isScorable()) {
            throw new InvalidResultException(model);
        }

        MiningFunction miningFunction = model.getMiningFunction();

        Map<FieldName, ?> predictions;
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(modelEvaluationContext);
                break;
            default:
                throw new UnsupportedFeatureException(model, miningFunction);
        }
        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    @Override
    public String getSummary() {
        return "Support vector machine";
    }
}
