package com.hankcs.hanlp.dependency.nnparser;

import com.hankcs.hanlp.dependency.nnparser.option.LearnOption;
import com.hankcs.hanlp.dependency.nnparser.util.Log;
import java.util.List;
import java.util.Map;

/* loaded from: classes2.dex */
/**
 * Feed-forward classifier with a single hidden layer over concatenated feature embeddings.
 *
 * <p>{@link #score} computes {@code W2 * cube(sum_k W1_k * E[:, a_k] + b1)}, where {@code a_k} is the
 * k-th input attribute id and {@code W1_k} is the k-th {@code embedding_size}-wide column block of
 * {@code W1}. When an (attribute, position) pair is listed in {@link #precomputation_id_encoder}, its
 * hidden-layer contribution is read from a column of {@link #saved} instead of being multiplied out.
 * ({@code cube} is {@code Matrix.cube()}; presumably an element-wise cube, confirm in {@code Matrix}.)
 */
public class NeuralNetworkClassifier {
    /** Embedding table, {@code embedding_size x nr_objects}; column {@code id} embeds object {@code id}. */
    Matrix E;
    /** Input-to-hidden weights, {@code hidden_layer_size x (embedding_size * nr_feature_types)}. */
    Matrix W1;
    /** Hidden-to-output weights, {@code nr_classes x hidden_layer_size}. */
    Matrix W2;
    double accuracy;
    double ada_alpha;
    double ada_eps;
    /** Hidden-layer bias, {@code hidden_layer_size x 1}. */
    Matrix b1;
    int batch_size;
    double dropout_probability;
    // Zero-filled by initialize_gradient_histories(), same shapes as W1/W2/E/b1.
    // NOTE(review): names suggest AdaGrad squared-gradient histories used by training code elsewhere.
    Matrix eg2E;
    Matrix eg2W1;
    Matrix eg2W2;
    Matrix eg2b1;
    boolean fix_embeddings;
    // Zero-filled by initialize(), same shapes as E/W1/W2/b1/saved; presumably gradient accumulators.
    Matrix grad_E;
    Matrix grad_W1;
    Matrix grad_W2;
    Matrix grad_b1;
    Matrix grad_saved;
    double lambda;
    double loss;
    int nr_threads;
    /**
     * Maps a packed key {@code attribute * nr_feature_types + position} to a column index of
     * {@link #saved} holding that pair's precomputed hidden-layer contribution.
     */
    Map<Integer, Integer> precomputation_id_encoder;
    /** Precomputed hidden-layer contributions, {@code hidden_layer_size x precomputed-feature count}. */
    Matrix saved;
    boolean initialized = false;
    int embedding_size = 0;
    int hidden_layer_size = 0;
    int nr_objects = 0;
    int nr_feature_types = 0;
    int nr_classes = 0;

    /**
     * Wraps already-trained parameters. Call {@link #canonical()} afterwards to derive the layer sizes.
     *
     * @param w1      input-to-hidden weights
     * @param w2      hidden-to-output weights
     * @param e       embedding table
     * @param b1      hidden-layer bias
     * @param saved   precomputed hidden-layer contributions
     * @param encoder packed (attribute, position) key to column of {@code saved}
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public NeuralNetworkClassifier(Matrix w1, Matrix w2, Matrix e, Matrix b1, Matrix saved, Map<Integer, Integer> encoder) {
        this.W1 = w1;
        this.W2 = w2;
        this.E = e;
        this.b1 = b1;
        this.saved = saved;
        this.precomputation_id_encoder = encoder;
    }

    /** Derives every size field from the shapes of the loaded parameter matrices. */
    /* JADX INFO: Access modifiers changed from: package-private */
    public void canonical() {
        this.hidden_layer_size = this.b1.rows();
        this.nr_feature_types = this.W1.cols() / this.E.rows();
        this.nr_classes = this.W2.rows();
        this.embedding_size = this.E.rows();
        // initialize() sizes E with one column per object, so the object count is recoverable too.
        this.nr_objects = this.E.cols();
    }

    double get_accuracy() {
        return this.accuracy;
    }

    double get_cost() {
        return this.loss;
    }

    /** Logs the shapes of all parameter matrices and the derived sizes. */
    void info() {
        Log.INFO_LOG("classifier: E(%d,%d)", this.E.rows(), this.E.cols());
        Log.INFO_LOG("classifier: W1(%d,%d)", this.W1.rows(), this.W1.cols());
        Log.INFO_LOG("classifier: b1(%d)", this.b1.rows());
        Log.INFO_LOG("classifier: W2(%d,%d)", this.W2.rows(), this.W2.cols());
        Log.INFO_LOG("classifier: saved(%d,%d)", this.saved.rows(), this.saved.cols());
        Log.INFO_LOG("classifier: precomputed size=%d", this.precomputation_id_encoder.size());
        Log.INFO_LOG("classifier: hidden layer size=%d", this.hidden_layer_size);
        Log.INFO_LOG("classifier: embedding size=%d", this.embedding_size);
        Log.INFO_LOG("classifier: number of classes=%d", this.nr_classes);
        Log.INFO_LOG("classifier: number of feature types=%d", this.nr_feature_types);
    }

    /**
     * Randomly initialises all parameters for training. A second call logs an error and does nothing.
     *
     * @param nrObjects           number of embeddable objects (columns of {@code E})
     * @param nrClasses           number of output classes (rows of {@code W2})
     * @param nrFeatureTypes      number of feature positions concatenated into the input layer
     * @param opt                 hyper-parameters (sizes, AdaGrad settings, init range, ...)
     * @param embeddings          pretrained rows: element 0 is the object id, the rest its embedding
     * @param precomputedFeatures packed feature keys that get a column in {@code saved}, in order
     */
    void initialize(int nrObjects, int nrClasses, int nrFeatureTypes, LearnOption opt, List<List<Double>> embeddings, List<Integer> precomputedFeatures) {
        if (this.initialized) {
            Log.ERROR_LOG("classifier: weight should not be initialized twice!", new Object[0]);
            return;
        }
        this.batch_size = opt.batch_size;
        this.fix_embeddings = opt.fix_embeddings;
        this.dropout_probability = opt.dropout_probability;
        this.lambda = opt.lambda;
        this.ada_eps = opt.ada_eps;
        this.ada_alpha = opt.ada_alpha;
        this.nr_feature_types = nrFeatureTypes;
        this.nr_objects = nrObjects;
        this.nr_classes = nrClasses;
        this.embedding_size = opt.embedding_size;
        this.hidden_layer_size = opt.hidden_layer_size;

        // The hidden layer sees one embedding per feature type, concatenated.
        int inputSize = this.embedding_size * this.nr_feature_types;
        // The bias shares W1's scale (computed from W1's shape), as in the original model code.
        double hiddenScale = uniformInitScale(this.hidden_layer_size, inputSize);
        this.W1 = Matrix.random(this.hidden_layer_size, inputSize).times(hiddenScale);
        this.b1 = Matrix.random(this.hidden_layer_size, 1).times(hiddenScale);
        this.W2 = Matrix.random(nrClasses, this.hidden_layer_size)
                .times(uniformInitScale(nrClasses, this.hidden_layer_size));
        this.E = Matrix.random(this.embedding_size, nrObjects).times(opt.init_range);
        loadPretrainedEmbeddings(embeddings);

        this.grad_W1 = Matrix.zero(this.W1.getRowDimension(), this.W1.getColumnDimension());
        this.grad_b1 = Matrix.zero(this.b1.rows(), 1);
        this.grad_W2 = Matrix.zero(this.W2.rows(), this.W2.cols());
        this.grad_E = Matrix.zero(this.E.rows(), this.E.cols());

        // Each precomputed feature key gets the next column of `saved`, in list order.
        Map<Integer, Integer> encoder = this.precomputation_id_encoder;
        for (int column = 0; column < precomputedFeatures.size(); column++) {
            encoder.put(precomputedFeatures.get(column), column);
        }
        this.saved = Matrix.zero(this.hidden_layer_size, encoder.size());
        this.grad_saved = Matrix.zero(this.hidden_layer_size, encoder.size());

        initialize_gradient_histories();
        this.initialized = true;

        info();
        Log.INFO_LOG("classifier: size of batch = %d", this.batch_size);
        Log.INFO_LOG("classifier: alpha = %e", this.ada_alpha);
        Log.INFO_LOG("classifier: eps = %e", this.ada_eps);
        Log.INFO_LOG("classifier: lambda = %e", this.lambda);
        Log.INFO_LOG("classifier: fix embedding = %s", this.fix_embeddings ? "true" : "false");
    }

    /**
     * Returns the {@code sqrt(6 / (rows + cols))} weight scale for a {@code rows x cols} matrix (the
     * Glorot-style fan-in/fan-out factor). The distribution it multiplies comes from {@code Matrix.random}.
     */
    private static double uniformInitScale(int rows, int cols) {
        return Math.sqrt(6.0d / (rows + cols));
    }

    /** Copies each pretrained vector into column {@code row[0]} of {@code E}; {@code row[j]} goes to row {@code j-1}. */
    private void loadPretrainedEmbeddings(List<List<Double>> embeddings) {
        for (List<Double> row : embeddings) {
            int objectId = row.get(0).intValue();
            for (int j = 1; j < row.size(); j++) {
                this.E.set(j - 1, objectId, row.get(j).doubleValue());
            }
        }
    }

    /** Allocates zero matrices, shaped like W1, b1, W2 and E, for the eg2* history fields. */
    void initialize_gradient_histories() {
        this.eg2W1 = Matrix.zero(this.W1.rows(), this.W1.cols());
        this.eg2b1 = Matrix.zero(this.b1.rows(), 1);
        this.eg2W2 = Matrix.zero(this.W2.rows(), this.W2.cols());
        this.eg2E = Matrix.zero(this.E.rows(), this.E.cols());
    }

    /**
     * Scores every class for one input.
     *
     * @param attributes attribute id for each feature position; position {@code k} uses the k-th
     *                   {@code embedding_size}-wide column block of {@code W1}
     * @param retval     output list; cleared, then filled with {@code nr_classes} scores in class order
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public void score(List<Integer> attributes, List<Double> retval) {
        Map<Integer, Integer> encoder = this.precomputation_id_encoder;
        Matrix hidden = Matrix.zero(this.hidden_layer_size, 1);
        for (int position = 0, offset = 0; position < attributes.size(); position++, offset += this.embedding_size) {
            int attribute = attributes.get(position);
            // Lookup key packs (attribute id, position) the same way as the precomputation keys.
            Integer column = encoder.get(attribute * this.nr_feature_types + position);
            if (column != null) {
                hidden.plusEquals(this.saved.col(column));
            } else {
                Matrix block = this.W1.block(0, offset, this.hidden_layer_size, this.embedding_size);
                hidden.plusEquals(block.times(this.E.col(attribute)));
            }
        }
        hidden.plusEquals(this.b1);
        Matrix output = this.W2.times(new Matrix(hidden.cube()));
        retval.clear();
        for (int k = 0; k < this.nr_classes; k++) {
            retval.add(output.get(k, 0));
        }
    }
}
