package org.jpmml.evaluator.neural_network;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.f;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.j;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.neural_network.Connection;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EntityProbabilityDistribution;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NormalizationUtil;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;

/* loaded from: classes6.dex */
/**
 * Evaluator for PMML {@code NeuralNetwork} models.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions. The network is
 * evaluated layer by layer: every neural input is computed from its derived field, every neuron sums
 * its weighted incoming connections plus an optional bias and applies the layer's (or, failing that,
 * the network's) activation function, and each layer's outputs are then optionally normalized.
 */
/* loaded from: classes6.dex */
public class NeuralNetworkEvaluator extends ModelEvaluator<NeuralNetwork> implements HasEntityRegistry<Entity> {

    /**
     * Cache of entity registries keyed by model. Neural inputs are registered first, followed by the
     * neurons of every layer in document order; {@link EntityUtil#put} receives a shared counter that
     * starts at 1.
     */
    private static final f<NeuralNetwork, j<String, Entity>> entityCache = CacheUtil.buildLoadingCache(new CacheLoader<NeuralNetwork, j<String, Entity>>() {
        @Override // com.google.common.cache.CacheLoader
        public j<String, Entity> load(NeuralNetwork neuralNetwork) {
            ImmutableBiMap.a builder = new ImmutableBiMap.a();
            AtomicInteger index = new AtomicInteger(1);
            Iterator<NeuralInput> inputs = neuralNetwork.getNeuralInputs().iterator();
            while (inputs.hasNext()) {
                builder = EntityUtil.put(inputs.next(), index, builder);
            }
            for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
                for (Neuron neuron : neuralLayer.getNeurons()) {
                    builder = EntityUtil.put(neuron, index, builder);
                }
            }
            return builder.b();
        }
    });

    /** Lazily resolved from {@link #entityCache}; see {@link #getEntityRegistry()}. */
    private transient j<String, Entity> entityRegistry;

    /**
     * Creates an evaluator for the {@code NeuralNetwork} model selected from the given PMML document.
     *
     * @param pmml the PMML document
     */
    public NeuralNetworkEvaluator(PMML pmml) {
        this(pmml, selectModel(pmml, NeuralNetwork.class));
    }

    /**
     * Creates an evaluator for the given model.
     *
     * @param pmml the PMML document containing the model
     * @param neuralNetwork the model to evaluate
     * @throws InvalidFeatureException if the model has no (or empty) {@code NeuralInputs},
     *     no {@code NeuralLayer} elements, or no (or empty) {@code NeuralOutputs}
     */
    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);
        this.entityRegistry = null;
        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        if (neuralInputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }
        if (!neuralInputs.hasNeuralInputs()) {
            throw new InvalidFeatureException(neuralInputs);
        }
        if (!neuralNetwork.hasNeuralLayers()) {
            throw new InvalidFeatureException(neuralNetwork);
        }
        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }
        if (!neuralOutputs.hasNeuralOutputs()) {
            throw new InvalidFeatureException(neuralOutputs);
        }
    }

    /**
     * Applies an activation function to a neuron's net input.
     *
     * <p>The layer-level {@code activationFunction} takes precedence over the network-level one. For
     * {@code THRESHOLD}, the layer-level {@code threshold} likewise takes precedence over the
     * network-level one.
     *
     * @param d the neuron's weighted input sum plus bias
     * @param neuralLayer the layer that contains the neuron
     * @return the activated value
     * @throws InvalidFeatureException if no activation function is defined, or if {@code THRESHOLD}
     *     is used without a threshold
     * @throws UnsupportedFeatureException if the activation function is not implemented here
     */
    private double activation(double d, NeuralLayer neuralLayer) {
        NeuralNetwork model = getModel();
        NeuralNetwork.ActivationFunction layerFunction = neuralLayer.getActivationFunction();
        NeuralNetwork.ActivationFunction activationFunction = (layerFunction != null) ? layerFunction : model.getActivationFunction();
        if (activationFunction == null) {
            throw new InvalidFeatureException(neuralLayer);
        }
        switch (activationFunction) {
            case THRESHOLD:
                Double threshold = neuralLayer.getThreshold();
                if (threshold == null) {
                    threshold = model.getThreshold();
                }
                if (threshold == null) {
                    throw new InvalidFeatureException(neuralLayer);
                }
                return d > threshold.doubleValue() ? 1.0d : 0.0d;
            case LOGISTIC:
                return 1.0d / (Math.exp(-d) + 1.0d);
            case TANH:
                return Math.tanh(d);
            case IDENTITY:
                return d;
            case EXPONENTIAL:
                return Math.exp(d);
            case RECIPROCAL:
                return 1.0d / d;
            case SQUARE:
                return d * d;
            case GAUSS:
                return Math.exp(-(d * d));
            case SINE:
                return Math.sin(d);
            case COSINE:
                return Math.cos(d);
            case ELLIOTT:
                return d / (Math.abs(d) + 1.0d);
            case ARCTAN:
                return Math.atan(d);
            case RECTIFIER:
                return Math.max(0.0d, d);
            default:
                // Blame the element that actually declared the unsupported function.
                if (layerFunction != null) {
                    throw new UnsupportedFeatureException(neuralLayer, activationFunction);
                }
                throw new UnsupportedFeatureException(model, activationFunction);
        }
    }

    /**
     * Evaluates the model as a classifier.
     *
     * <p>Each {@code NeuralOutput} must resolve to a {@code NormDiscrete} expression. The output
     * neuron's value is recorded as the probability of that expression's category, in a distribution
     * kept per target field. The resulting distributions are then post-processed per target field by
     * {@link TargetUtil#evaluateClassificationInternal}.
     *
     * @param modelEvaluationContext the evaluation context
     * @return target field name to classification result
     * @throws UnsupportedFeatureException if an output expression is not {@code NormDiscrete}
     * @throws InvalidFeatureException if an output references a neuron that was not computed
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Map<FieldName, ? extends Classification> evaluateClassification(ModelEvaluationContext modelEvaluationContext) {
        NeuralNetwork model = getModel();
        j<String, Entity> entityRegistry = getEntityRegistry();
        Map<String, Double> neuronValues = evaluateRaw(modelEvaluationContext);
        if (neuronValues == null) {
            // A neural input could not be evaluated (missing value).
            return TargetUtil.evaluateClassificationDefault(modelEvaluationContext);
        }
        // NOTE(review): kept raw-typed as in the original; the exact return type of
        // TargetUtil.evaluateClassificationInternal is not visible from this file.
        LinkedHashMap result = new LinkedHashMap();
        Iterator<NeuralOutput> outputs = model.getNeuralOutputs().iterator();
        while (outputs.hasNext()) {
            NeuralOutput neuralOutput = outputs.next();
            String outputNeuron = neuralOutput.getOutputNeuron();
            Entity entity = entityRegistry.get(outputNeuron);
            Expression outputExpression = getOutputExpression(neuralOutput);
            if (!(outputExpression instanceof NormDiscrete)) {
                throw new UnsupportedFeatureException(outputExpression);
            }
            Double probability = neuronValues.get(outputNeuron);
            if (probability == null) {
                throw new InvalidFeatureException(neuralOutput);
            }
            NormDiscrete normDiscrete = (NormDiscrete) outputExpression;
            FieldName field = normDiscrete.getField();
            EntityProbabilityDistribution distribution = (EntityProbabilityDistribution) result.get(field);
            if (distribution == null) {
                distribution = new EntityProbabilityDistribution(entityRegistry);
                result.put(field, distribution);
            }
            distribution.put(entity, normDiscrete.getValue(), probability);
        }
        for (TargetField targetField : getTargetFields()) {
            FieldName name = targetField.getName();
            result.put(name, TargetUtil.evaluateClassificationInternal(targetField, (Classification) result.get(name), modelEvaluationContext));
        }
        return result;
    }

    /**
     * Runs the forward pass of the network.
     *
     * @param evaluationContext the evaluation context used to compute the neural inputs
     * @return neuron/input id to value, for every neural input and every neuron of every layer; or
     *     {@code null} if any neural input evaluates to {@code null}
     * @throws InvalidFeatureException if a {@code Connection} refers to an id that has not been
     *     computed yet (unknown id, or a neuron in the same or a later layer)
     */
    private Map<String, Double> evaluateRaw(EvaluationContext evaluationContext) {
        NeuralNetwork model = getModel();
        Map<String, Double> values = new HashMap<>(getEntityRegistry().size());
        Iterator<NeuralInput> inputs = model.getNeuralInputs().iterator();
        while (inputs.hasNext()) {
            NeuralInput neuralInput = inputs.next();
            FieldValue value = ExpressionUtil.evaluate(neuralInput.getDerivedField(), evaluationContext);
            if (value == null) {
                return null;
            }
            values.put(neuralInput.getId(), value.asDouble());
        }
        // Scratch map for the current layer. Outputs are only published after the whole layer is
        // computed, so neurons cannot see sibling outputs within their own layer.
        Map<String, Double> layerValues = new HashMap<>();
        for (NeuralLayer neuralLayer : model.getNeuralLayers()) {
            layerValues.clear();
            for (Neuron neuron : neuralLayer.getNeurons()) {
                double sum = 0.0d;
                for (Connection connection : neuron.getConnections()) {
                    Double source = values.get(connection.getFrom());
                    if (source == null) {
                        throw new InvalidFeatureException(connection);
                    }
                    sum += source.doubleValue() * connection.getWeight();
                }
                Double bias = neuron.getBias();
                if (bias != null) {
                    sum += bias.doubleValue();
                }
                layerValues.put(neuron.getId(), activation(sum, neuralLayer));
            }
            normalizeNeuronOutputs(neuralLayer, layerValues);
            values.putAll(layerValues);
        }
        return values;
    }

    /**
     * Evaluates the model as a regressor.
     *
     * <p>Each {@code NeuralOutput} must resolve either to a {@code FieldRef}, in which case the raw
     * neuron value is used, or to a {@code NormContinuous}, in which case the value is
     * de-normalized. The results are then post-processed per target field by
     * {@link TargetUtil#evaluateRegressionInternal}.
     *
     * @param modelEvaluationContext the evaluation context
     * @return target field name to predicted value
     * @throws UnsupportedFeatureException if an output expression is neither {@code FieldRef} nor
     *     {@code NormContinuous}
     * @throws InvalidFeatureException if an output references a neuron that was not computed
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext modelEvaluationContext) {
        NeuralNetwork model = getModel();
        Map<String, Double> neuronValues = evaluateRaw(modelEvaluationContext);
        if (neuronValues == null) {
            // A neural input could not be evaluated (missing value).
            return TargetUtil.evaluateRegressionDefault(modelEvaluationContext);
        }
        Map<FieldName, Object> result = new LinkedHashMap<>();
        Iterator<NeuralOutput> outputs = model.getNeuralOutputs().iterator();
        while (outputs.hasNext()) {
            NeuralOutput neuralOutput = outputs.next();
            String outputNeuron = neuralOutput.getOutputNeuron();
            Expression outputExpression = getOutputExpression(neuralOutput);
            if (!(outputExpression instanceof FieldRef) && !(outputExpression instanceof NormContinuous)) {
                throw new UnsupportedFeatureException(outputExpression);
            }
            Double value = neuronValues.get(outputNeuron);
            if (value == null) {
                throw new InvalidFeatureException(neuralOutput);
            }
            if (outputExpression instanceof FieldRef) {
                result.put(((FieldRef) outputExpression).getField(), value);
            } else {
                NormContinuous normContinuous = (NormContinuous) outputExpression;
                result.put(normContinuous.getField(), Double.valueOf(NormalizationUtil.denormalize(normContinuous, value.doubleValue())));
            }
        }
        for (TargetField targetField : getTargetFields()) {
            FieldName name = targetField.getName();
            result.put(name, TargetUtil.evaluateRegressionInternal(targetField, result.get(name), modelEvaluationContext));
        }
        return result;
    }

    /**
     * Resolves the expression of a {@code NeuralOutput}'s derived field.
     *
     * <p>A {@code FieldRef} to a {@code DataField} is returned as is. A {@code FieldRef} to a
     * {@code DerivedField} is dereferenced one level to that field's expression. Any other
     * expression is returned unchanged.
     *
     * @param neuralOutput the output element
     * @return the resolved expression, never {@code null}
     * @throws InvalidFeatureException if the derived field or an expression is missing, or the
     *     referenced field is neither a data field nor a derived field
     * @throws MissingFieldException if a referenced field cannot be resolved
     */
    private Expression getOutputExpression(NeuralOutput neuralOutput) {
        DerivedField derivedField = neuralOutput.getDerivedField();
        if (derivedField == null) {
            throw new InvalidFeatureException(neuralOutput);
        }
        Expression expression = derivedField.getExpression();
        if (expression == null) {
            throw new InvalidFeatureException(derivedField);
        }
        if (!(expression instanceof FieldRef)) {
            return expression;
        }
        FieldRef fieldRef = (FieldRef) expression;
        FieldName field = fieldRef.getField();
        TypeDefinitionField resolvedField = resolveField(field);
        if (resolvedField == null) {
            throw new MissingFieldException(field, fieldRef);
        }
        if (resolvedField instanceof DataField) {
            return expression;
        }
        if (!(resolvedField instanceof DerivedField)) {
            throw new InvalidFeatureException(fieldRef);
        }
        DerivedField referencedField = (DerivedField) resolvedField;
        Expression referencedExpression = referencedField.getExpression();
        if (referencedExpression == null) {
            throw new InvalidFeatureException(referencedField);
        }
        return referencedExpression;
    }

    /**
     * Normalizes a layer's neuron outputs in place.
     *
     * <p>The layer-level {@code normalizationMethod} takes precedence over the network-level one. A
     * missing network-level value is treated as {@code NONE}, which is the PMML default.
     *
     * @param neuralLayer the layer whose outputs are being normalized
     * @param map neuron id to activated value; modified in place
     * @throws UnsupportedFeatureException if the normalization method is not implemented here
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Map<String, Double> map) {
        NeuralNetwork model = getModel();
        NeuralNetwork.NormalizationMethod layerMethod = neuralLayer.getNormalizationMethod();
        NeuralNetwork.NormalizationMethod normalizationMethod = (layerMethod != null) ? layerMethod : model.getNormalizationMethod();
        if (normalizationMethod == null) {
            normalizationMethod = NeuralNetwork.NormalizationMethod.NONE;
        }
        switch (normalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                Classification.normalize(map);
                return;
            case SOFTMAX:
                Classification.normalizeSoftMax(map);
                return;
            default:
                // Blame the element that actually declared the unsupported method.
                if (layerMethod != null) {
                    throw new UnsupportedFeatureException(neuralLayer, normalizationMethod);
                }
                throw new UnsupportedFeatureException(model, normalizationMethod);
        }
    }

    /**
     * Evaluates the model and then the model's {@code Output} fields.
     *
     * @param modelEvaluationContext the evaluation context
     * @return the target and output field values
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor
     *     classification
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        NeuralNetwork model = getModel();
        if (!model.isScorable()) {
            throw new InvalidResultException(model);
        }
        MiningFunction miningFunction = model.getMiningFunction();
        Map<FieldName, ?> predictions;
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(modelEvaluationContext);
                break;
            default:
                throw new UnsupportedFeatureException(model, miningFunction);
        }
        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    /**
     * Returns the id-to-entity registry of this model's neural inputs and neurons. The registry is
     * fetched from {@link #entityCache} on first use and then memoized in this instance.
     *
     * @return the entity registry
     */
    @Override // org.jpmml.evaluator.HasEntityRegistry
    public j<String, Entity> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = (j) getValue(entityCache);
        }
        return this.entityRegistry;
    }

    /** {@inheritDoc} */
    @Override // org.jpmml.evaluator.Evaluator
    public String getSummary() {
        return "Neural network";
    }
}
