package com.mindspore.styletransfer;

import android.content.Context;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.util.Log;
import com.mindspore.lite.LiteSession;
import com.mindspore.lite.MSTensor;
import com.mindspore.lite.Model;
import com.mindspore.lite.config.MSConfig;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/* loaded from: classes.dex */
/**
 * Runs style transfer with two MindSpore Lite models.
 *
 * <p>The first model ("style_predict_quant.ms") takes the style image. Its float output
 * (presumably a style embedding) is fed, together with the content image, into the second model
 * ("style_transfer_quant.ms"), whose float output is turned back into a bitmap.
 *
 * <p>Models and sessions are created by {@link #init()}, which the constructor calls. If any
 * loading, configuration or compilation step fails, the failure is logged and {@link #execute}
 * returns {@code null}. Per-run timings are stored in unsynchronized mutable fields, so one
 * instance should not be used from several threads at once.
 */
public class StyleTransferModelExecutor {
    /** Width/height, in pixels, of the content image given to the transform model. */
    private static final int CONTENT_IMAGE_SIZE = 384;
    /** Width/height, in pixels, of the style image given to the predict model. */
    private static final int STYLE_IMAGE_SIZE = 256;
    /** Number of values per pixel in the transform model's flat output. */
    private static final int OUTPUT_CHANNELS = 3;
    private static final String TAG = "StyleTransferModelExecutor";
    private static final String MS_TAG = "MS_LITE";
    private static final int NUM_THREADS = 4;
    // NOTE(review): presumably MSConfig device type 0 = CPU and bind mode 2 = mid-core binding;
    // confirm against the MindSpore Lite MSConfig documentation.
    private static final int DEVICE_TYPE = 0;
    private static final int CPU_BIND_MODE = 2;

    private final Context mContext;
    private LiteSession predictSession;
    private LiteSession transformSession;
    private MSConfig msConfig;
    private Model predictModel;
    private Model transformModel;
    /** True only when both sessions were created and both graphs compiled successfully. */
    private boolean initialized;

    private long fullExecutionTime;
    private long postProcessTime;
    private long preProcessTime;
    private long stylePredictTime;
    private long styleTransferTime;

    /**
     * Creates the executor and immediately loads and compiles both models.
     *
     * @param context Android context used to load the model files
     * @param z currently ignored
     */
    public StyleTransferModelExecutor(Context context, boolean z) {
        this.mContext = context;
        init();
    }

    /**
     * Serializes floats into a byte array using the platform's native byte order.
     *
     * @param fArr values to serialize
     * @return a new array of {@code fArr.length * 4} bytes
     */
    public static byte[] floatArrayToByteArray(float[] fArr) {
        ByteBuffer buffer = ByteBuffer.allocate(fArr.length * Float.BYTES);
        buffer.order(ByteOrder.nativeOrder());
        buffer.asFloatBuffer().put(fArr);
        return buffer.array();
    }

    /** Builds the human-readable timing summary attached to each result. */
    private String formatExecutionLog() {
        return new StringBuilder()
                .append("Input Image Size:").append(CONTENT_IMAGE_SIZE * CONTENT_IMAGE_SIZE)
                .append("\nPre-process execution time: ").append(this.preProcessTime).append(" ms")
                .append("\nPredicting style execution time: ").append(this.stylePredictTime).append(" ms")
                .append("\nTransferring style execution time: ").append(this.styleTransferTime).append(" ms")
                .append("\nPost-process execution time: ").append(this.postProcessTime).append(" ms")
                .append("\nFull execution time: ").append(this.fullExecutionTime).append(" ms")
                .toString();
    }

    /**
     * Returns the float data of the last output tensor listed by {@code session}.
     *
     * <p>If the session lists several outputs, only the last one's data is returned.
     *
     * @return the data, or {@code null} (after logging) if a listed tensor is missing or the session
     *     has no outputs
     */
    private static float[] lastOutputFloatData(LiteSession session, String sessionName) {
        List<String> names = session.getOutputTensorNames();
        Map<String, MSTensor> tensorsByName = session.getOutputMapByTensor();
        float[] data = null;
        for (String name : names) {
            MSTensor tensor = tensorsByName.get(name);
            if (tensor == null) {
                Log.e(MS_TAG, "Can not find " + sessionName + " output " + name);
                return null;
            }
            data = tensor.getFloatData();
        }
        if (data == null) {
            Log.e(MS_TAG, sessionName + " produced no output data");
        }
        return data;
    }

    /**
     * Reshapes the transform model's flat row-major output into a {@code [1][384][384][3]} array.
     *
     * <p>Element {@code [0][r][c][ch]} is {@code flat[(r * 384 + c) * 3 + ch]}. Presumably this is
     * NHWC (rows = image height); confirm against the model.
     *
     * @return the reshaped array, or {@code null} (after logging) if {@code flat} is too short
     */
    private static float[][][][] reshapeOutput(float[] flat) {
        int expected = CONTENT_IMAGE_SIZE * CONTENT_IMAGE_SIZE * OUTPUT_CHANNELS;
        if (flat.length < expected) {
            Log.e(MS_TAG, "Transform_session output has " + flat.length + " values, expected " + expected);
            return null;
        }
        float[][][][] image = new float[1][CONTENT_IMAGE_SIZE][CONTENT_IMAGE_SIZE][OUTPUT_CHANNELS];
        int offset = 0;
        for (int row = 0; row < CONTENT_IMAGE_SIZE; row++) {
            for (int col = 0; col < CONTENT_IMAGE_SIZE; col++) {
                System.arraycopy(flat, offset, image[0][row][col], 0, OUTPUT_CHANNELS);
                offset += OUTPUT_CHANNELS;
            }
        }
        return image;
    }

    /**
     * Feeds the predict session's output and the content image into the transform session, runs
     * it, and reshapes the result.
     *
     * <p>Records {@code styleTransferTime}. It also sets {@code postProcessTime} to a start timestamp,
     * which {@link #execute} later turns into a duration.
     *
     * @param byteBuffer content image as produced by {@code ImageUtils.bitmapToByteBuffer}
     * @return the styled image, or {@code null} (after logging) on any failure
     */
    private float[][][][] getDealData(ByteBuffer byteBuffer) {
        float[] predictOutput = lastOutputFloatData(this.predictSession, "Predict_session");
        if (predictOutput == null) {
            return null;
        }
        List<MSTensor> inputs = this.transformSession.getInputs();
        if (inputs.size() < 2) {
            Log.e(MS_TAG, "Transform_session expects 2 inputs but has " + inputs.size());
            return null;
        }
        inputs.get(0).setData(floatArrayToByteArray(predictOutput));
        inputs.get(1).setData(byteBuffer);

        this.styleTransferTime = SystemClock.uptimeMillis();
        if (!this.transformSession.runGraph()) {
            Log.e(MS_TAG, "Run Transform_graph failed");
            return null;
        }
        this.styleTransferTime = SystemClock.uptimeMillis() - this.styleTransferTime;
        Log.d(TAG, "Style apply Time to run: " + this.styleTransferTime);

        // Start of post-processing; execute() subtracts this after bitmap conversion.
        this.postProcessTime = SystemClock.uptimeMillis();
        float[] transformOutput = lastOutputFloatData(this.transformSession, "Transform_session");
        if (transformOutput == null) {
            return null;
        }
        return reshapeOutput(transformOutput);
    }

    /**
     * Applies the style of {@code bitmap2} to the content image {@code bitmap}.
     *
     * @param bitmap content image, scaled to 384x384 before inference
     * @param bitmap2 style image, scaled to 256x256 before inference
     * @return the styled bitmap with timing information, or {@code null} if the executor failed to
     *     initialize or any inference step failed
     */
    public ModelExecutionResult execute(Bitmap bitmap, Bitmap bitmap2) {
        Log.i(TAG, "running models");
        if (!this.initialized) {
            Log.e(MS_TAG, "Models are not initialized; cannot run style transfer");
            return null;
        }
        this.fullExecutionTime = SystemClock.uptimeMillis();
        this.preProcessTime = SystemClock.uptimeMillis();
        ByteBuffer contentBuffer = ImageUtils.bitmapToByteBuffer(
                bitmap, CONTENT_IMAGE_SIZE, CONTENT_IMAGE_SIZE, 0.0f, 255.0f);
        ByteBuffer styleBuffer = ImageUtils.bitmapToByteBuffer(
                bitmap2, STYLE_IMAGE_SIZE, STYLE_IMAGE_SIZE, 0.0f, 255.0f);
        List<MSTensor> inputs = this.predictSession.getInputs();
        if (inputs.size() != 1) {
            Log.e(MS_TAG, "Predict_session expects 1 input but has " + inputs.size());
            return null;
        }
        inputs.get(0).setData(styleBuffer);
        this.preProcessTime = SystemClock.uptimeMillis() - this.preProcessTime;

        this.stylePredictTime = SystemClock.uptimeMillis();
        if (!this.predictSession.runGraph()) {
            Log.e(MS_TAG, "Run Predict_graph failed");
            return null;
        }
        this.stylePredictTime = SystemClock.uptimeMillis() - this.stylePredictTime;
        Log.d(TAG, "Style Predict Time to run: " + this.stylePredictTime);

        float[][][][] styled = getDealData(contentBuffer);
        if (styled == null) {
            // postProcessTime may not have been started; do not report a bogus result.
            return null;
        }
        Bitmap styledBitmap = ImageUtils.convertArrayToBitmap(styled, CONTENT_IMAGE_SIZE, CONTENT_IMAGE_SIZE);
        this.postProcessTime = SystemClock.uptimeMillis() - this.postProcessTime;
        this.fullExecutionTime = SystemClock.uptimeMillis() - this.fullExecutionTime;
        Log.d(TAG, "Time to run everything: " + this.fullExecutionTime);
        return new ModelExecutionResult(styledBitmap, this.preProcessTime, this.stylePredictTime,
                this.styleTransferTime, this.postProcessTime, this.fullExecutionTime, formatExecutionLog());
    }

    /**
     * Creates the shared config and both sessions.
     *
     * <p>The config is freed exactly once, whether or not creation succeeds.
     *
     * @return {@code true} if the config and both sessions initialized successfully
     */
    private boolean createSessions() {
        this.msConfig = new MSConfig();
        try {
            if (!this.msConfig.init(DEVICE_TYPE, NUM_THREADS, CPU_BIND_MODE)) {
                Log.e(MS_TAG, "Init context failed");
                return false;
            }
            this.predictSession = new LiteSession();
            if (!this.predictSession.init(this.msConfig)) {
                Log.e(MS_TAG, "Create Predict_session failed");
                return false;
            }
            this.transformSession = new LiteSession();
            if (!this.transformSession.init(this.msConfig)) {
                Log.e(MS_TAG, "Create Transform_session failed");
                return false;
            }
            return true;
        } finally {
            this.msConfig.free();
        }
    }

    /**
     * Loads both models, creates their sessions and compiles the graphs.
     *
     * <p>Stops at the first failure (logging it) and leaves the executor uninitialized, so
     * {@link #execute} will return {@code null}. Each model buffer is released exactly once after
     * all steps have run or been abandoned.
     */
    public void init() {
        this.initialized = false;
        this.predictModel = new Model();
        this.transformModel = new Model();
        try {
            if (!this.predictModel.loadModel(this.mContext, "style_predict_quant.ms")) {
                Log.e(MS_TAG, "Load style_predict_model failed");
                return;
            }
            if (!this.transformModel.loadModel(this.mContext, "style_transfer_quant.ms")) {
                Log.e(MS_TAG, "Load style_transform_model failed");
                return;
            }
            if (!createSessions()) {
                return;
            }
            if (!this.predictSession.compileGraph(this.predictModel)) {
                Log.e(MS_TAG, "Compile style_predict graph failed");
                return;
            }
            if (!this.transformSession.compileGraph(this.transformModel)) {
                Log.e(MS_TAG, "Compile style_transform graph failed");
                return;
            }
            this.initialized = true;
        } finally {
            this.predictModel.freeBuffer();
            this.transformModel.freeBuffer();
        }
    }
}
