package com.mayabot.nlp.fasttext.train;

import com.github.mikephil.charting.utils.Utils;
import com.lzy.okserver.download.DownloadInfo;
import com.mayabot.nlp.common.IntArrayList;
import com.mayabot.nlp.fasttext.FastText;
import com.mayabot.nlp.fasttext.Model;
import com.mayabot.nlp.fasttext.args.Args;
import com.mayabot.nlp.fasttext.args.ModelName;
import com.mayabot.nlp.fasttext.dictionary.Dictionary;
import com.mayabot.nlp.fasttext.loss.LossName;
import com.mayabot.nlp.fasttext.utils.LogUtilsKt;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import kotlin.Metadata;
import kotlin.jvm.internal.Intrinsics;
import kotlin.jvm.internal.StringCompanionObject;

/* compiled from: FastTextTrain.kt */
@Metadata(bv = {1, 0, 3}, d1 = {"\u0000f\n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0002\b\u0005\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\t\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0006\n\u0002\u0010\u000b\n\u0000\n\u0002\u0010\u0002\n\u0000\n\u0002\u0010\u0007\n\u0002\b\u0003\n\u0002\u0010 \n\u0002\u0010\u001c\n\u0002\u0018\u0002\n\u0002\b\u0003\u0018\u00002\u00020\u0001:\u0002./B\u0015\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\u0006\u0010\u0004\u001a\u00020\u0005¢\u0006\u0002\u0010\u0006J\b\u0010\"\u001a\u00020#H\u0002J \u0010$\u001a\u00020%2\u0006\u0010&\u001a\u00020'2\u0006\u0010\u0010\u001a\u00020\u00112\u0006\u0010(\u001a\u00020#H\u0002J\b\u0010&\u001a\u00020'H\u0002J\u001a\u0010)\u001a\u00020%2\u0012\u0010*\u001a\u000e\u0012\n\u0012\b\u0012\u0004\u0012\u00020-0,0+R\u0011\u0010\u0007\u001a\u00020\u0003¢\u0006\b\n\u0000\u001a\u0004\b\b\u0010\tR\u0011\u0010\n\u001a\u00020\u000b¢\u0006\b\n\u0000\u001a\u0004\b\f\u0010\rR\u0011\u0010\u0004\u001a\u00020\u0005¢\u0006\b\n\u0000\u001a\u0004\b\u000e\u0010\u000fR\u000e\u0010\u0010\u001a\u00020\u0011X\u0082\u0004¢\u0006\u0002\n\u0000R\u0011\u0010\u0012\u001a\u00020\u0013¢\u0006\b\n\u0000\u001a\u0004\b\u0014\u0010\u0015R\u000e\u0010\u0016\u001a\u00020\u0013X\u0082\u000e¢\u0006\u0002\n\u0000R\u000e\u0010\u0017\u001a\u00020\u0018X\u0082\u0004¢\u0006\u0002\n\u0000R\u0011\u0010\u0002\u001a\u00020\u0003¢\u0006\b\n\u0000\u001a\u0004\b\u0019\u0010\tR\"\u0010\u001a\u001a\n\u0018\u00010\u001bj\u0004\u0018\u0001`\u001cX\u0086\u000e¢\u0006\u000e\n\u0000\u001a\u0004\b\u001d\u0010\u001e\"\u0004\b\u001f\u0010 R\u000e\u0010!\u001a\u00020\u0013X\u0082\u0004¢\u0006\u0002\n\u0000¨\u00060"}, d2 = {"Lcom/mayabot/nlp/fasttext/train/FastTextTrain;", "", "trainArgs", "Lcom/mayabot/nlp/fasttext/args/Args;", "fastText", "Lcom/mayabot/nlp/fasttext/FastText;", 
"(Lcom/mayabot/nlp/fasttext/args/Args;Lcom/mayabot/nlp/fasttext/FastText;)V", "args", "getArgs", "()Lcom/mayabot/nlp/fasttext/args/Args;", "dict", "Lcom/mayabot/nlp/fasttext/dictionary/Dictionary;", "getDict", "()Lcom/mayabot/nlp/fasttext/dictionary/Dictionary;", "getFastText", "()Lcom/mayabot/nlp/fasttext/FastText;", "loss", "Lcom/mayabot/nlp/fasttext/train/FastTextTrain$ShareDouble;", "ntokens", "", "getNtokens", "()J", "startTime", "tokenCount", "Ljava/util/concurrent/atomic/AtomicLong;", "getTrainArgs", "trainException", "Ljava/lang/Exception;", "Lkotlin/Exception;", "getTrainException", "()Ljava/lang/Exception;", "setTrainException", "(Ljava/lang/Exception;)V", "wantProcessTotalTokens", "keepTraining", "", "printInfo", "", "progress", "", "stop", "startThreads", "sources", "", "", "Lcom/mayabot/nlp/fasttext/train/SampleLine;", "ShareDouble", "TrainThread", "mynlp"}, k = 1, mv = {1, 4, 1})
/* loaded from: classes.dex */
public final class FastTextTrain {
    // Training arguments; the constructor assigns this the same instance as trainArgs.
    private final Args args;
    // Dictionary obtained from fastText.getDict() in the constructor.
    private final Dictionary dict;
    // The FastText instance passed to the constructor.
    private final FastText fastText;
    // Shared loss holder, initialised to -1.0; startThreads() only prints progress once its value is >= 0.
    private final ShareDouble loss;
    // dict.getNtokens() captured at construction time.
    private final long ntokens;
    // System.currentTimeMillis() captured in the constructor; used for elapsed-time reporting.
    private long startTime;
    // Token counter, starts at 0; compared against wantProcessTotalTokens by keepTraining().
    private final AtomicLong tokenCount;
    // Training arguments exactly as passed to the constructor.
    private final Args trainArgs;
    // When non-null, keepTraining() returns false and startThreads() rethrows it after joining the workers.
    private Exception trainException;
    // trainArgs.getEpoch() * ntokens: the token budget after which keepTraining() returns false.
    private final long wantProcessTotalTokens;

    /* compiled from: FastTextTrain.kt */
    @Metadata(bv = {1, 0, 3}, d1 = {"\u0000 \n\u0002\u0018\u0002\n\u0002\u0010\u0000\n\u0000\n\u0002\u0010\u0006\n\u0002\b\u0005\n\u0002\u0010\u0002\n\u0002\b\u0002\n\u0002\u0010\u0007\n\u0000\u0018\u00002\u00020\u0001B\r\u0012\u0006\u0010\u0002\u001a\u00020\u0003¢\u0006\u0002\u0010\u0004J\u000e\u0010\b\u001a\u00020\t2\u0006\u0010\n\u001a\u00020\u0003J\u0006\u0010\u000b\u001a\u00020\fR\u001a\u0010\u0002\u001a\u00020\u0003X\u0086\u000e¢\u0006\u000e\n\u0000\u001a\u0004\b\u0005\u0010\u0006\"\u0004\b\u0007\u0010\u0004¨\u0006\r"}, d2 = {"Lcom/mayabot/nlp/fasttext/train/FastTextTrain$ShareDouble;", "", "value", "", "(D)V", "getValue", "()D", "setValue", "set", "", "v", "toFloat", "", "mynlp"}, k = 1, mv = {1, 4, 1})
    /* loaded from: classes.dex */
    // Minimal mutable holder for one double value. It performs no synchronisation
    // of its own; the field is a plain (non-volatile) double.
    public static final class ShareDouble {
        private double value;

        /**
         * Creates a holder whose initial content is {@code d}.
         *
         * @param d initial value
         */
        public ShareDouble(double d) {
            setValue(d);
        }

        /**
         * @return the value currently stored
         */
        public final double getValue() {
            return value;
        }

        /**
         * Stores {@code v}; equivalent to {@link #setValue(double)}.
         *
         * @param v new value
         */
        public final void set(double v) {
            setValue(v);
        }

        /**
         * Stores {@code d}.
         *
         * @param d new value
         */
        public final void setValue(double d) {
            value = d;
        }

        /**
         * @return the stored value narrowed to {@code float}
         */
        public final float toFloat() {
            return (float) getValue();
        }
    }

    /* compiled from: FastTextTrain.kt */
    @Metadata(bv = {1, 0, 3}, d1 = {"\u0000F\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\b\n\u0000\n\u0002\u0010\u001c\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0010\t\n\u0002\b\u0003\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u0002\n\u0000\n\u0002\u0018\u0002\n\u0000\n\u0002\u0010\u0007\n\u0000\n\u0002\u0018\u0002\n\u0002\b\u0005\b\u0080\u0004\u0018\u00002\u00020\u0001B\u001b\u0012\u0006\u0010\u0002\u001a\u00020\u0003\u0012\f\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005¢\u0006\u0002\u0010\u0007J(\u0010\u0015\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001cH\u0002J\b\u0010\u001d\u001a\u00020\u0016H\u0016J(\u0010\u001e\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001cH\u0002J0\u0010\u001f\u001a\u00020\u00162\u0006\u0010\u0011\u001a\u00020\u00122\u0006\u0010\u0017\u001a\u00020\u00182\u0006\u0010\u0019\u001a\u00020\u001a2\u0006\u0010\u001b\u001a\u00020\u001c2\u0006\u0010 \u001a\u00020\u001cH\u0002R\u001a\u0010\b\u001a\u00020\u0003X\u0086\u000e¢\u0006\u000e\n\u0000\u001a\u0004\b\t\u0010\n\"\u0004\b\u000b\u0010\fR\u0011\u0010\r\u001a\u00020\u000e¢\u0006\b\n\u0000\u001a\u0004\b\u000f\u0010\u0010R\u0014\u0010\u0004\u001a\b\u0012\u0004\u0012\u00020\u00060\u0005X\u0082\u0004¢\u0006\u0002\n\u0000R\u0011\u0010\u0011\u001a\u00020\u0012¢\u0006\b\n\u0000\u001a\u0004\b\u0013\u0010\u0014R\u000e\u0010\u0002\u001a\u00020\u0003X\u0082\u0004¢\u0006\u0002\n\u0000¨\u0006!"}, d2 = {"Lcom/mayabot/nlp/fasttext/train/FastTextTrain$TrainThread;", "Ljava/lang/Runnable;", "threadId", "", "parts", "", "Lcom/mayabot/nlp/fasttext/train/SampleLine;", "(Lcom/mayabot/nlp/fasttext/train/FastTextTrain;ILjava/lang/Iterable;)V", "localTokenCount", "getLocalTokenCount", "()I", "setLocalTokenCount", "(I)V", "ntokens", "", 
"getNtokens", "()J", DownloadInfo.STATE, "Lcom/mayabot/nlp/fasttext/Model$State;", "getState", "()Lcom/mayabot/nlp/fasttext/Model$State;", "cbow", "", "model", "Lcom/mayabot/nlp/fasttext/Model;", "lr", "", "line", "Lcom/mayabot/nlp/common/IntArrayList;", "run", "skipgram", "supervised", "labels", "mynlp"}, k = 1, mv = {1, 4, 1})
    /* loaded from: classes.dex */
    public final class TrainThread implements Runnable {
        // Token counter local to this TrainThread instance; exposed through
        // getLocalTokenCount()/setLocalTokenCount(). How run() uses it is not
        // visible because that method body was not decompiled.
        private int localTokenCount;
        // dict.getNtokens() of the enclosing trainer, captured in the constructor.
        private final long ntokens;
        // The SampleLine source handed to the constructor.
        private final Iterable<SampleLine> parts;
        // Model.State built in the constructor from dim, the output row count and the seed.
        private final Model.State state;
        // Compiler-generated reference to the enclosing FastTextTrain instance.
        final /* synthetic */ FastTextTrain this$0;
        // Index passed to the constructor (startThreads() passes the position in the sources list).
        private final int threadId;

        @Metadata(bv = {1, 0, 3}, k = 3, mv = {1, 4, 1})
        /* loaded from: classes.dex */
        // Compiler-generated lookup table translating ModelName ordinals into
        // switch-case numbers: sup -> 1, cbow -> 2, sg -> 3. Any other ordinal
        // keeps the array default of 0. Presumably consumed by the `when`
        // expression inside run(), whose body was not decompiled.
        public final /* synthetic */ class WhenMappings {
            public static final /* synthetic */ int[] $EnumSwitchMapping$0;

            static {
                // Fill a local table first, then publish it to the final field.
                final int[] table = new int[ModelName.values().length];
                table[ModelName.sg.ordinal()] = 3;
                table[ModelName.cbow.ordinal()] = 2;
                table[ModelName.sup.ordinal()] = 1;
                $EnumSwitchMapping$0 = table;
            }
        }

        public TrainThread(FastTextTrain fastTextTrain, int i, Iterable<SampleLine> parts) {
            Intrinsics.checkNotNullParameter(parts, "parts");
            this.this$0 = fastTextTrain;
            this.threadId = i;
            this.parts = parts;
            this.state = new Model.State(fastTextTrain.getArgs().getDim(), fastTextTrain.getFastText().getOutput().getRow(), fastTextTrain.getTrainArgs().getSeed());
            this.ntokens = fastTextTrain.getDict().getNtokens();
        }

        /**
         * Performs one CBOW training pass over {@code line}.
         *
         * <p>For every position a window radius in {@code [1, ws]} is drawn from the
         * state's random generator. The subwords of all in-range neighbours inside
         * that window (excluding the position itself) are gathered, and the model is
         * updated with the position as target.
         *
         * <p>The decompiled form of the window loop was {@code while (true)} with no
         * exit once the offset reached the upper bound, so it never terminated. It is
         * now an ordinary bounded loop over {@code -window..window}.
         *
         * @param state per-worker model state (provides the random generator)
         * @param model model to update
         * @param lr    learning rate for this update
         * @param line  word ids of the current line
         */
        private final void cbow(Model.State state, Model model, float lr, IntArrayList line) {
            // Kotlin synthetic default-arguments constructor, kept exactly as decompiled.
            // Mask 3 presumably selects the defaults for both parameters — TODO confirm.
            IntArrayList context = new IntArrayList(0, null, 3, null);
            int tokenTotal = line.getElementsCount();
            for (int center = 0; center < tokenTotal; center++) {
                int window = state.getRng().nextInt(this.this$0.getArgs().getWs()) + 1;
                context.clear();
                for (int offset = -window; offset <= window; offset++) {
                    if (offset == 0) {
                        continue; // the centre word is the target, not part of its own context
                    }
                    int neighbour = center + offset;
                    if (neighbour >= 0 && neighbour < line.getElementsCount()) {
                        context.addAll(this.this$0.getDict().getSubwords(line.get(neighbour)));
                    }
                }
                model.update(context, line, center, lr, state);
            }
        }

        /**
         * Performs one skip-gram training pass over {@code line}.
         *
         * <p>For every position a window radius in {@code [1, ws]} is drawn from the
         * state's random generator. The model is then updated once per in-range
         * neighbour inside that window, using the subwords of the current position as
         * input and the neighbour position as target.
         *
         * <p>The decompiled form of the window loop was {@code while (true)} with no
         * exit once the offset reached the upper bound, so it never terminated. It is
         * now an ordinary bounded loop over {@code -window..window}.
         *
         * @param state per-worker model state (provides the random generator)
         * @param model model to update
         * @param lr    learning rate for this update
         * @param line  word ids of the current line
         */
        private final void skipgram(Model.State state, Model model, float lr, IntArrayList line) {
            int tokenTotal = line.getElementsCount();
            for (int center = 0; center < tokenTotal; center++) {
                int window = state.getRng().nextInt(this.this$0.getArgs().getWs()) + 1;
                IntArrayList subwords = this.this$0.getDict().getSubwords(line.get(center));
                for (int offset = -window; offset <= window; offset++) {
                    if (offset == 0) {
                        continue; // never predict the word from itself
                    }
                    int neighbour = center + offset;
                    if (neighbour >= 0 && neighbour < line.getElementsCount()) {
                        model.update(subwords, line, neighbour, lr, state);
                    }
                }
            }
        }

        /**
         * Performs one supervised update for a labelled line.
         *
         * <p>Does nothing when either the label list or the word list is empty. With
         * the {@code ova} loss the update targets
         * {@code Model.INSTANCE.getKAllLabelsAsTarget()}; otherwise a single label
         * index is drawn at random from {@code labels}.
         *
         * @param state  per-worker model state (provides the random generator)
         * @param model  model to update
         * @param lr     learning rate for this update
         * @param line   word ids of the current line
         * @param labels label ids attached to the line
         */
        private final void supervised(Model.State state, Model model, float lr, IntArrayList line, IntArrayList labels) {
            if (labels.getElementsCount() == 0) {
                return;
            }
            if (line.getElementsCount() == 0) {
                return;
            }
            final boolean oneVsAll = this.this$0.getArgs().getLoss() == LossName.ova;
            if (!oneVsAll) {
                model.update(line, labels, state.getRng().nextInt(labels.getElementsCount()), lr, state);
                return;
            }
            model.update(line, labels, Model.INSTANCE.getKAllLabelsAsTarget(), lr, state);
        }

        /**
         * @return the token counter held by this worker instance
         */
        public final int getLocalTokenCount() {
            return localTokenCount;
        }

        /**
         * @return the dictionary token total captured when this worker was constructed
         */
        public final long getNtokens() {
            return ntokens;
        }

        /**
         * @return the {@code Model.State} created for this worker in its constructor
         */
        public final Model.State getState() {
            return state;
        }

        /* JADX WARN: Code restructure failed: missing block: B:41:0x011d, code lost:
        
            throw new java.lang.IllegalStateException("不可能为空".toString());
         */
        /*
         * Decompilation placeholder for the worker's main loop.
         *
         * The decompiler could not restructure the original bytecode (293 instructions),
         * so the real training logic is NOT present here. As written, this method
         * unconditionally throws UnsupportedOperationException when invoked.
         *
         * The warning above shows one statement that was lost: an IllegalStateException
         * whose message (in Chinese) translates roughly to "cannot be null/empty".
         *
         * NOTE(review): presumably the original iterates over `parts` and dispatches to
         * supervised/cbow/skipgram through WhenMappings — this cannot be confirmed from
         * this file; recover the body from the original Kotlin source or a bytecode dump
         * before relying on this class.
         */
        @Override // java.lang.Runnable
        /*
            Code decompiled incorrectly, please refer to instructions dump.
            To view partially-correct add '--show-bad-code' argument
        */
        public void run() {
            /*
                Method dump skipped, instructions count: 293
                To view this dump add '--comments-level debug' option
            */
            throw new UnsupportedOperationException("Method not decompiled: com.mayabot.nlp.fasttext.train.FastTextTrain.TrainThread.run():void");
        }

        /**
         * Replaces the token counter held by this worker instance.
         *
         * @param i new counter value
         */
        public final void setLocalTokenCount(int i) {
            localTokenCount = i;
        }
    }

    /**
     * Prepares a training session.
     *
     * <p>Captures the start time, takes the dictionary from {@code fastText}, and
     * derives the total token budget as {@code epoch * dict.getNtokens()}.
     *
     * @param trainArgs training arguments; must not be {@code null}
     * @param fastText  model container being trained; must not be {@code null}
     */
    public FastTextTrain(Args trainArgs, FastText fastText) {
        Intrinsics.checkNotNullParameter(trainArgs, "trainArgs");
        Intrinsics.checkNotNullParameter(fastText, "fastText");

        // Both argument fields refer to the same instance.
        this.trainArgs = trainArgs;
        this.args = trainArgs;
        this.fastText = fastText;

        // Progress bookkeeping: nothing processed yet, loss unset (-1), clock started.
        this.tokenCount = new AtomicLong(0L);
        this.loss = new ShareDouble(-1.0d);
        this.startTime = System.currentTimeMillis();

        this.dict = fastText.getDict();
        this.ntokens = this.dict.getNtokens();
        this.wantProcessTotalTokens = trainArgs.getEpoch() * this.ntokens;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Tells whether training should continue.
     *
     * @return {@code true} while the token counter is below the token budget and no
     *         training exception has been recorded
     */
    public final boolean keepTraining() {
        if (this.tokenCount.longValue() >= this.wantProcessTotalTokens) {
            return false;
        }
        return this.trainException == null;
    }

    /**
     * Formats one status line (progress, throughput, learning rate, loss, ETA) and
     * passes it to {@code LogUtilsKt.logger}.
     *
     * <p>Fixes over the decompiled form: elapsed time is now measured in fractional
     * seconds. Previously integer division truncated it to 0 during the first
     * second, and the guard only rejected negative times, so the throughput divided
     * by zero and printed Infinity/NaN. A zero elapsed time now falls back to the
     * defaults (0 words/sec, 720 h ETA).
     *
     * @param progress fraction of the token budget processed so far; 1.0 when finished
     * @param loss     holder whose current value is printed as the loss
     * @param stop     {@code true} for the final line, which omits learning rate and ETA
     */
    private final void printInfo(float progress, ShareDouble loss, boolean stop) {
        double elapsedSeconds = (System.currentTimeMillis() - this.startTime) / 1000.0d;
        // The displayed learning rate is scaled by the remaining (unscaled) progress fraction.
        double lr = this.trainArgs.getLr() * (1.0d - progress);

        float percent = progress;
        long etaSeconds = 2592000L; // 720 h placeholder shown while no estimate is possible
        double wordsPerSecondPerThread = 0.0d;
        if (!(progress <= 0 || elapsedSeconds <= 0)) {
            percent = progress * 100;
            etaSeconds = (long) (((100.0f - percent) * elapsedSeconds) / percent);
            wordsPerSecondPerThread =
                    (this.tokenCount.doubleValue() / elapsedSeconds) / this.trainArgs.getThread();
        }

        long etaHours = etaSeconds / 3600;
        long withinHour = etaSeconds % 3600;
        long etaMinutes = withinHour / 60;
        long etaRemainder = withinHour % 60;

        StringBuilder sb = new StringBuilder();
        sb.append("Progress: ");
        sb.append(String.format("%2.2f", percent));
        sb.append("% words/sec/thread: ");
        sb.append(String.format("%8.0f", wordsPerSecondPerThread));
        if (!stop) {
            sb.append(String.format(" lr: %2.5f", lr));
        }
        // NOTE(review): "arg.loss" may be a typo for "avg.loss"; left untouched in case logs are parsed.
        sb.append(String.format(" arg.loss: %2.5f", loss.toFloat()));
        if (!stop) {
            sb.append(" ETA: " + etaHours + "h " + etaMinutes + "m " + etaRemainder + "s");
        }
        LogUtilsKt.logger(sb);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * @return the token counter divided by the total token budget, both as {@code float}
     */
    public final float progress() {
        final float processed = this.tokenCount.floatValue();
        final float budget = (float) this.wantProcessTotalTokens;
        return processed / budget;
    }

    /**
     * @return the training arguments (the same instance as {@link #getTrainArgs()})
     */
    public final Args getArgs() {
        return args;
    }

    /**
     * @return the dictionary taken from the FastText instance at construction time
     */
    public final Dictionary getDict() {
        return dict;
    }

    /**
     * @return the FastText instance supplied to the constructor
     */
    public final FastText getFastText() {
        return fastText;
    }

    /**
     * @return the dictionary token total captured at construction time
     */
    public final long getNtokens() {
        return ntokens;
    }

    /**
     * @return the training arguments supplied to the constructor
     */
    public final Args getTrainArgs() {
        return trainArgs;
    }

    /**
     * @return the recorded training exception, or {@code null} when none has been set
     */
    public final Exception getTrainException() {
        return trainException;
    }

    /**
     * Records a training exception. A non-null value makes {@link #keepTraining()}
     * return {@code false} and is rethrown by {@code startThreads} after the workers
     * have been joined.
     *
     * @param exc exception to record, or {@code null} to clear it
     */
    public final void setTrainException(Exception exc) {
        trainException = exc;
    }

    /**
     * Starts one worker thread per source, reports progress until training is done,
     * waits for all workers, then either rethrows the recorded training exception or
     * logs the final summary.
     *
     * <p>Changes over the decompiled form:
     * <ul>
     *   <li>the worker list is typed ({@code List<Thread>}) instead of a raw
     *       {@code ArrayList} with casts;</li>
     *   <li>a {@code dict.getNtokens()} call whose result was discarded has been
     *       removed;</li>
     *   <li>checked exceptions ({@code InterruptedException} from sleep/join and the
     *       stored training exception) were thrown without being declared, which is
     *       not valid Java. They are still propagated unchanged, with the same type
     *       and no wrapping, via {@link #rethrowUnchecked(Throwable)}, so the method
     *       signature stays as it was.</li>
     * </ul>
     *
     * @param sources one sample source per worker thread; must not be {@code null}
     */
    public final void startThreads(List<? extends Iterable<SampleLine>> sources) {
        Intrinsics.checkNotNullParameter(sources, "sources");
        int workerCount = sources.size();
        List<Thread> workers = new ArrayList<Thread>(workerCount);
        for (int i = 0; i < workerCount; i++) {
            workers.add(new Thread(new TrainThread(this, i, sources.get(i))));
        }
        for (Thread worker : workers) {
            worker.start();
        }
        try {
            while (keepTraining()) {
                Thread.sleep(100L);
                // The loss starts at -1; stay quiet until a real value has been published.
                if (this.loss.toFloat() >= 0) {
                    float progress = progress();
                    LogUtilsKt.logger("\r");
                    printInfo(progress, this.loss, false);
                }
            }
            for (Thread worker : workers) {
                worker.join();
            }
        } catch (InterruptedException interrupted) {
            // Propagate the interruption to the caller unchanged.
            throw FastTextTrain.<RuntimeException>rethrowUnchecked(interrupted);
        }
        Exception failure = this.trainException;
        if (failure != null) {
            throw FastTextTrain.<RuntimeException>rethrowUnchecked(failure);
        }
        LogUtilsKt.logger("\r");
        printInfo(1.0f, this.loss, true);
        LogUtilsKt.loggerln();
        LogUtilsKt.loggerln("Train use time " + (System.currentTimeMillis() - this.startTime) + " ms");
    }

    /**
     * Throws {@code failure} as-is without requiring a {@code throws} clause,
     * matching the unchecked propagation of the original Kotlin code. This method
     * never returns normally; the return type only lets callers write
     * {@code throw rethrowUnchecked(e)} for flow analysis.
     */
    @SuppressWarnings("unchecked") // deliberate: the cast is erased, the original throwable is thrown
    private static <E extends Throwable> RuntimeException rethrowUnchecked(Throwable failure) throws E {
        throw (E) failure;
    }
}
