package com.xiaomi.migameservice.ml.tensorflow;

import android.app.Application;
import android.util.Log;
import com.xiaomi.migameservice.ml.Classifier;
import com.xiaomi.migameservice.ml.MLResult;
import com.xiaomi.migameservice.ml.ModelContext;
import com.xiaomi.migameservice.ml.datas.AudioFeature;
import com.xiaomi.migameservice.ml.datas.ModelInfo;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

/* loaded from: classes.dex */
/**
 * Audio classifier backed by a TensorFlow (Mobile) frozen graph.
 *
 * <p>Lifecycle as implemented here: {@link #applyModelInfo} reads tensor names, the optional
 * confidence threshold, the local model path and the label map; {@link #start()} loads the graph
 * from that path; {@link #recognize(List)} runs one inference over the first PCM buffer; {@link
 * #stop()} closes the TensorFlow interface.
 */
public class AudioClassifier extends Classifier<AudioFeature> {
    public static final boolean DEBUG = true;

    /** Key in {@link ModelInfo#getParams()} holding the minimum score for a non-silence label. */
    public static final String PARAMS_KEY_CONFIDENCE = "confidence";

    /** Number of samples fed per inference (SAMPLE_RATE * SAMPLE_DURATION_MS / 1000). */
    private static final int RECORDING_LENGTH = 8000;

    public static final int SAMPLE_DURATION_MS = 500;
    private static final int SAMPLE_RATE = 16000;
    public static final String TAG = "AudioClassifier";

    /** Label returned when no classification could be made. */
    private static final String LABEL_UNKNOWN = "unknown";

    /** Label returned for index 0 or for a top score at or below the confidence threshold. */
    private static final String LABEL_SILENCE = "silence";

    /** Loaded graph; {@code null} until {@link #start()} succeeds and after {@link #stop()}. */
    protected TensorFlowInferenceInterface inferenceInterface;

    protected String mModelName;
    protected boolean logStats = false;

    /** Input tensor names: [0] receives the audio samples, [1] receives the sample rate. */
    protected String[] mInputTensors = {"decoded_sample_data:0", "decoded_sample_data:1"};

    /** Output tensor names; only [0] is fetched. */
    protected String[] mOutputTensors = {"labels_softmax"};

    /** Local path of the frozen graph, set by {@link #applyModelInfo}. */
    protected String mModelPbPath = null;

    private TFContext mModelContext = null;

    /** Value fed to the sample-rate input tensor. */
    int[] SAMPLE_RATE_LIST = {SAMPLE_RATE};

    /** Output index -> label; always contains -1 -> "unknown" once a model info is applied. */
    private HashMap<Integer, String> mLabelsMap = new HashMap<>();

    /** Scores at or below this value are reported as silence. */
    private float mConfidenceRatio = 0.5f;

    /**
     * Rebuilds {@link #mLabelsMap} from the model's label map, whose keys are parsed as integers.
     *
     * @throws NumberFormatException if a label key is not an integer
     */
    private void loadLabelsMap(ModelInfo modelInfo) {
        // Start from a clean map so re-applying a different model does not keep stale labels;
        // the map size also determines the size of the inference output buffer.
        this.mLabelsMap.clear();
        HashMap<String, String> labelTitleMap = modelInfo.getLabelTitleMap();
        if (labelTitleMap != null) {
            for (String key : labelTitleMap.keySet()) {
                this.mLabelsMap.put(Integer.valueOf(key), labelTitleMap.get(key));
            }
        }
        this.mLabelsMap.put(-1, LABEL_UNKNOWN);
    }

    /**
     * Copies tensor names, confidence threshold, model path and labels from {@code modelInfo}.
     *
     * <p>A missing or malformed {@value #PARAMS_KEY_CONFIDENCE} parameter keeps the current
     * threshold.
     *
     * @return always {@code true}
     */
    @Override // com.xiaomi.migameservice.ml.Classifier
    public boolean applyModelInfo(Application application, ModelInfo modelInfo) {
        this.mModelName = modelInfo.getModelName();
        this.mInputTensors = (String[]) modelInfo.getInputTensorNames().toArray(new String[0]);
        this.mOutputTensors = (String[]) modelInfo.getOutputTensorNames().toArray(new String[0]);
        HashMap<String, String> params = modelInfo.getParams();
        if (params != null && params.containsKey(PARAMS_KEY_CONFIDENCE)) {
            String confidence = params.get(PARAMS_KEY_CONFIDENCE);
            if (confidence != null) {
                try {
                    this.mConfidenceRatio = Float.parseFloat(confidence);
                } catch (NumberFormatException e) {
                    Log.w(TAG, "Ignoring malformed confidence param: " + confidence, e);
                }
            }
        }
        for (String name : this.mInputTensors) {
            Log.d(TAG, "mInputTensors: " + name);
        }
        for (String name : this.mOutputTensors) {
            Log.d(TAG, "mOutputTensorNames: " + name);
        }
        this.mModelPbPath = modelInfo.getLocalModelPath();
        loadLabelsMap(modelInfo);
        return true;
    }

    /**
     * Returns the index of the first maximum element, or 0 for an empty array.
     */
    protected int argmax(float[] fArr) {
        int best = 0;
        for (int i = 0; i < fArr.length; i++) {
            if (fArr[i] > fArr[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Stores the stat-logging flag; not read elsewhere in this class. */
    @Override // com.xiaomi.migameservice.ml.Classifier
    public void enableStatLogging(boolean z) {
        this.logStats = z;
    }

    /** Returns whether a graph is currently loaded. */
    @Override // com.xiaomi.migameservice.ml.Classifier
    public boolean isStarted() {
        return this.inferenceInterface != null;
    }

    /** Keeps the context if it is a {@link TFContext}; otherwise logs an error. */
    @Override // com.xiaomi.migameservice.ml.Classifier
    protected void loadModel(ModelContext modelContext) {
        if (modelContext instanceof TFContext) {
            this.mModelContext = (TFContext) modelContext;
            return;
        }
        Log.e(TAG, "wrong model context with modelContext : " + modelContext);
    }

    /**
     * Classifies one audio window without throwing.
     *
     * @param list PCM buffers; only the first is used
     * @return a single-element list holding the recognized label, "silence", or "unknown" when
     *     inference could not be performed
     */
    public List<MLResult> recognize(List<short[]> list) {
        try {
            return recognizeInner(list);
        } catch (Exception e) {
            // Best-effort: any inference failure (e.g. IllegalStateException from TensorFlow or an
            // output/buffer size mismatch) degrades to an "unknown" result.
            Log.e(TAG, "Audio recognition failed", e);
            return unknownResult();
        }
    }

    /** Builds the single-element result list carrying the "unknown" label. */
    private static List<MLResult> unknownResult() {
        ArrayList<String> labels = new ArrayList<>();
        labels.add(LABEL_UNKNOWN);
        List<MLResult> results = new ArrayList<>();
        results.add(new MLResult(labels));
        return results;
    }

    /** Not supported by this classifier; always returns {@code null}. */
    @Override // com.xiaomi.migameservice.ml.Classifier
    public List<MLResult> recognizeImage(List<AudioFeature> list) {
        return null;
    }

    /**
     * Runs one inference over the first {@link #RECORDING_LENGTH} samples of {@code list.get(0)}.
     *
     * <p>Returns "unknown" if the classifier is not started or the input is missing/too short.
     * Index 0, or a top score at or below the confidence threshold, is reported as "silence".
     * TensorFlow errors propagate to the caller.
     *
     * @param list PCM buffers; only the first is used
     * @return a single-element result list
     */
    public List<MLResult> recognizeInner(List<short[]> list) {
        long startMillis = System.currentTimeMillis();
        if (!isStarted()) {
            Log.w(TAG, toString() + " is not started yet, skip task.");
            return unknownResult();
        }
        short[] samples = (list == null || list.isEmpty()) ? null : list.get(0);
        if (samples == null || samples.length < RECORDING_LENGTH) {
            Log.w(TAG, "Audio input shorter than " + RECORDING_LENGTH + " samples, skip task.");
            return unknownResult();
        }

        // Normalize 16-bit PCM by 32767 before feeding the model.
        float[] input = new float[RECORDING_LENGTH];
        for (int i = 0; i < RECORDING_LENGTH; i++) {
            input[i] = samples[i] / 32767.0f;
        }
        float[] scores = new float[this.mLabelsMap.size()];

        // Sample rate is fed with no dimensions; audio is fed as RECORDING_LENGTH x 1.
        this.inferenceInterface.feed(this.mInputTensors[1], this.SAMPLE_RATE_LIST, new long[0]);
        this.inferenceInterface.feed(this.mInputTensors[0], input, RECORDING_LENGTH, 1);
        this.inferenceInterface.run(this.mOutputTensors);
        this.inferenceInterface.fetch(this.mOutputTensors[0], scores);

        int best = argmax(scores);
        Log.d(TAG, "[WindEar] AudioRecognize ret=" + best + ", score=" + scores[best]);
        ArrayList<String> labels = new ArrayList<>();
        if (best == 0 || scores[best] <= this.mConfidenceRatio) {
            labels.add(LABEL_SILENCE);
        } else {
            String label = this.mLabelsMap.get(best);
            if (label == null) {
                // The model produced an index with no configured label.
                label = LABEL_UNKNOWN;
            }
            Log.d(TAG, "[WindEar] AudioRecognize recognized: " + label);
            labels.add(label);
        }
        List<MLResult> results = new ArrayList<>();
        results.add(new MLResult(labels));
        Log.d(TAG, "[WindEar] Audio recognize spent time : "
                + (System.currentTimeMillis() - startMillis));
        return results;
    }

    /**
     * Loads the frozen graph from {@link #mModelPbPath}.
     *
     * <p>On I/O failure (or when no model path was applied) the error is logged and the
     * classifier stays not started. The model stream is always closed.
     */
    @Override // com.xiaomi.migameservice.ml.Classifier
    public void start() {
        if (this.mModelPbPath == null) {
            Log.e(TAG, "No model path; applyModelInfo() must be called before start()");
            return;
        }
        try (FileInputStream in = new FileInputStream(new File(this.mModelPbPath))) {
            this.inferenceInterface = new TensorFlowInferenceInterface(in);
        } catch (FileNotFoundException e) {
            Log.e(TAG, "Model file not found: " + this.mModelPbPath, e);
        } catch (IOException e) {
            Log.e(TAG, "Failed to read model file: " + this.mModelPbPath, e);
        }
    }

    /** Closes the loaded graph, if any, and marks the classifier as not started. */
    @Override // com.xiaomi.migameservice.ml.Classifier
    public void stop() {
        if (this.inferenceInterface != null) {
            this.inferenceInterface.close();
        }
        this.inferenceInterface = null;
    }

    /** Drops the stored model context. */
    @Override // com.xiaomi.migameservice.ml.Classifier
    protected void unloadModel() {
        this.mModelContext = null;
    }
}
