package org.tensorflow.lite.support.b;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.a.e;

/* compiled from: TensorLabel.java */
/* loaded from: classes8.dex */
/**
 * Associates label lists with axes of a tensor buffer and exposes the buffer contents keyed by
 * those labels (decompiled from {@code TensorLabel.java}; identifiers are obfuscated).
 *
 * <p>The "labeled axis" used by every accessor is the first axis whose size is larger than 1.
 * The axis-label map passed to the constructor is stored by reference, not copied.
 */
public class c {
    /** Label lists keyed by axis index; validated against the shape in the constructor. */
    private final Map<Integer, List<String>> axisLabels;
    /** The tensor buffer whose values are being labeled. */
    private final org.tensorflow.lite.support.c.a tensorBuffer;
    /** Value of {@code tensorBuffer.c()} captured at construction; compared against label counts. */
    private final int[] shape;

    /**
     * Labels the first axis of {@code aVar} whose size is larger than 1 with {@code list}.
     *
     * @param list labels for that axis; its size must equal the size of the axis
     * @param aVar the tensor buffer to label
     * @throws IllegalArgumentException if no axis of the buffer has a size larger than 1
     */
    public c(@NonNull List<String> list, @NonNull org.tensorflow.lite.support.c.a aVar) {
        this(singleAxisLabels(firstAxisLargerThanOne(aVar), list), aVar);
    }

    /**
     * Creates a labeled view from an explicit axis-to-labels map.
     *
     * <p>Each entry is checked through the {@code e.a} precondition helpers: the axis index must be
     * within the shape, the label list must be non-null, and the label count must equal the size
     * of that axis.
     *
     * @param map label lists keyed by axis index
     * @param aVar the tensor buffer to label
     */
    public c(@NonNull Map<Integer, List<String>> map, @NonNull org.tensorflow.lite.support.c.a aVar) {
        e.a(map, "Axis labels cannot be null.");
        e.a(aVar, "Tensor Buffer cannot be null.");
        this.axisLabels = map;
        this.tensorBuffer = aVar;
        this.shape = aVar.c();
        for (Map.Entry<Integer, List<String>> entry : map.entrySet()) {
            int axis = entry.getKey().intValue();
            List<String> labels = entry.getValue();
            // The range check must run before shape[axis] is read below.
            e.a(axis >= 0 && axis < this.shape.length, "Invalid axis id: " + axis);
            e.a(labels, "Label list is null on axis " + axis);
            e.a(this.shape[axis] == labels.size(), "Label number " + labels.size() + " mismatch the shape on axis " + axis);
        }
    }

    /**
     * Returns the index of the first axis of {@code aVar} whose size is larger than 1.
     *
     * @throws IllegalArgumentException if every axis has size 1 or less
     */
    private static int firstAxisLargerThanOne(@NonNull org.tensorflow.lite.support.c.a aVar) {
        int[] dims = aVar.c();
        for (int axis = 0; axis < dims.length; axis++) {
            if (dims[axis] > 1) {
                return axis;
            }
        }
        throw new IllegalArgumentException("Cannot find an axis to label. A valid axis to label should have size larger than 1.");
    }

    /** Wraps a single axis index and its labels into a one-entry, insertion-ordered map. */
    private static Map<Integer, List<String>> singleAxisLabels(int i, List<String> list) {
        Map<Integer, List<String>> labelsByAxis = new LinkedHashMap<>();
        labelsByAxis.put(Integer.valueOf(i), list);
        return labelsByAxis;
    }

    /**
     * Splits the buffer along the labeled axis and returns one sub-buffer per label.
     *
     * <p>Labels must have been registered on the first axis whose size is larger than 1. Each
     * returned buffer is loaded with a slice of the underlying {@link ByteBuffer} and the shape
     * of the axes that follow the labeled one. Note that this rewinds and repositions the
     * underlying byte buffer.
     *
     * @return an insertion-ordered map from label to its sub-buffer
     */
    @NonNull
    public Map<String, org.tensorflow.lite.support.c.a> a() {
        int labeledAxis = firstAxisLargerThanOne(this.tensorBuffer);
        Map<String, org.tensorflow.lite.support.c.a> labelToBuffer = new LinkedHashMap<>();
        e.a(this.axisLabels.containsKey(Integer.valueOf(labeledAxis)), "get a <String, TensorBuffer> map requires the labels are set on the first non-1 axis.");
        List<String> labels = this.axisLabels.get(Integer.valueOf(labeledAxis));
        DataType dataType = this.tensorBuffer.d();
        // NOTE(review): g() is presumably the per-element byte size and b() the flat element
        // count, given how they combine into a byte offset below -- confirm against the buffer class.
        int typeSize = this.tensorBuffer.g();
        int flatSize = this.tensorBuffer.b();
        ByteBuffer bytes = this.tensorBuffer.a();
        bytes.rewind();
        int sliceLength = (flatSize / this.shape[labeledAxis]) * typeSize;
        e.a(labels, "Label list should never be null");
        int index = 0;
        for (String label : labels) {
            bytes.position(index * sliceLength);
            ByteBuffer slice = bytes.slice();
            // slice() does not carry over the byte order, so restore it before limiting.
            slice.order(bytes.order()).limit(sliceLength);
            org.tensorflow.lite.support.c.a subBuffer = org.tensorflow.lite.support.c.a.a(dataType);
            subBuffer.a(slice, Arrays.copyOfRange(this.shape, labeledAxis + 1, this.shape.length));
            labelToBuffer.put(label, subBuffer);
            index++;
        }
        return labelToBuffer;
    }

    /**
     * Returns each label mapped to the float value at the same position in the buffer.
     *
     * <p>Only valid when the labeled axis is the last axis of the shape.
     *
     * @return an insertion-ordered map from label to value
     */
    @NonNull
    public Map<String, Float> b() {
        int labeledAxis = firstAxisLargerThanOne(this.tensorBuffer);
        e.b(labeledAxis == this.shape.length - 1, "get a <String, Scalar> map is only valid when the only labeled axis is the last one.");
        List<String> labels = this.axisLabels.get(Integer.valueOf(labeledAxis));
        // Named "values" rather than "e": a local called "e" would shadow the imported
        // precondition class and break the e.b(...) call below.
        float[] values = this.tensorBuffer.e();
        // Same guard a() applies: labels may be absent if they were set on another axis.
        e.a(labels, "Label list should never be null");
        e.b(labels.size() == values.length);
        Map<String, Float> labelToValue = new LinkedHashMap<>();
        int index = 0;
        for (String label : labels) {
            labelToValue.put(label, Float.valueOf(values[index]));
            index++;
        }
        return labelToValue;
    }

    /**
     * Returns one {@code a} instance per label, built from the label and the float value at the
     * same position in the buffer.
     *
     * <p>Only valid when the labeled axis is the last axis of the shape.
     *
     * @return the list of label/value objects in label order
     */
    @NonNull
    public List<a> c() {
        int labeledAxis = firstAxisLargerThanOne(this.tensorBuffer);
        e.b(labeledAxis == this.shape.length - 1, "get a Category list is only valid when the only labeled axis is the last one.");
        List<String> labels = this.axisLabels.get(Integer.valueOf(labeledAxis));
        // Named "values" rather than "e" so the imported precondition class stays reachable.
        float[] values = this.tensorBuffer.e();
        // Same guard a() applies: labels may be absent if they were set on another axis.
        e.a(labels, "Label list should never be null");
        e.b(labels.size() == values.length);
        List<a> categories = new ArrayList<>(labels.size());
        int index = 0;
        for (String label : labels) {
            categories.add(new a(label, values[index]));
            index++;
        }
        return categories;
    }
}
