/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.impl;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.SplittableRandom;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.impl.SkeletalIndependentRegressionModel;

/**
 * Base class for regression trainers that fit one independent inner model per output dimension.
 * <p>
 * This class performs the following steps:
 * <ul>
 *   <li>extracts feature vectors, example weights and per-dimension targets from the dataset;</li>
 *   <li>manages the trainer RNG and the train invocation counter;</li>
 *   <li>assembles the model provenance.</li>
 * </ul>
 * Subclasses provide the per-dimension fitting routine ({@link #trainDimension}) and the model
 * factory ({@link #createModel}).
 * <p>
 * The RNG and the invocation counter are only read or written while holding this instance's
 * monitor.
 *
 * @param <T> the type of the per-dimension inner model.
 */
public abstract class SkeletalIndependentRegressionTrainer<T> implements Trainer<Regressor> {
    @Config(description="Seed for the RNG, may be unused.")
    private long seed = 1L;

    /** Trainer RNG; split once per train call. Guarded by {@code this}. */
    private SplittableRandom rng;

    /** Number of train calls performed (or replayed via setInvocationCount). Guarded by {@code this}. */
    private int trainInvocationCounter = 0;

    /**
     * Constructor for subclasses.
     */
    protected SkeletalIndependentRegressionTrainer() {
    }

    /**
     * (Re)creates the trainer RNG from the configured seed.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Trains a model on the supplied dataset with no extra run provenance.
     *
     * @param examples the training data.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs.
     */
    public SkeletalIndependentRegressionModel train(Dataset<Regressor> examples) {
        return train(examples, Collections.emptyMap());
    }

    /**
     * Trains a model on the supplied dataset, continuing from the current invocation count.
     *
     * @param examples the training data.
     * @param runProvenance extra provenance recorded in the model.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs.
     */
    public SkeletalIndependentRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, -1);
    }

    /**
     * Trains one inner model per output dimension and wraps them via {@link #createModel}.
     *
     * @param examples the training data.
     * @param runProvenance extra provenance recorded in the model.
     * @param invocationCount if not {@code -1}, the RNG is reset to the state it would have after
     *                        this many train calls before training.
     * @return the trained model.
     * @throws IllegalArgumentException if the dataset contains unknown outputs, or if
     *                                  {@code invocationCount} is negative and not {@code -1}.
     */
    public SkeletalIndependentRegressionModel train(Dataset<Regressor> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        if (examples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }

        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            if (invocationCount != -1) {
                setInvocationCount(invocationCount);
            }
            localRNG = sharedRng().split();
            // Provenance is captured before the counter is bumped, matching the original ordering.
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }

        ImmutableOutputInfo<Regressor> outputInfo = examples.getOutputIDInfo();
        ImmutableFeatureMap featureMap = examples.getFeatureIDMap();
        int numExamples = examples.size();
        boolean needBias = useBias();

        // Targets are stored per dimension: outputs[dimensionId][exampleIndex].
        float[] weights = new float[numExamples];
        double[][] outputs = new double[outputInfo.size()][numExamples];
        SparseVector[] inputs = new SparseVector[numExamples];
        int i = 0;
        for (Example<Regressor> e : examples) {
            inputs[i] = SparseVector.createSparseVector(e, featureMap, needBias);
            weights[i] = e.getWeight();
            for (Regressor.DimensionTuple r : e.getOutput()) {
                outputs[outputInfo.getID(r)][i] = r.getValue();
            }
            i++;
        }

        // One inner model per dimension, keyed by the dimension's first name, in domain order.
        // All dimensions of this call share the same split RNG.
        Set<Regressor> domain = outputInfo.getDomain();
        Map<String, T> models = new LinkedHashMap<>();
        for (Regressor r : domain) {
            int id = outputInfo.getID(r);
            T innerModel = trainDimension(outputs[id], inputs, weights, localRNG);
            models.put(r.getNames()[0], innerModel);
        }

        ModelProvenance provenance = new ModelProvenance(getModelClassName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
        return createModel(models, provenance, featureMap, outputInfo);
    }

    /**
     * Returns the number of train invocations performed so far.
     * <p>
     * Synchronized so the read observes writes made under the lock in train/setInvocationCount.
     *
     * @return the invocation count.
     */
    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    /**
     * Resets the RNG from the seed and replays {@code invocationCount} splits, so that the next
     * train call behaves as if it were invocation number {@code invocationCount}.
     *
     * @param invocationCount the invocation count to restore; must be non-negative.
     * @throws IllegalArgumentException if {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        rng = new SplittableRandom(seed);
        for (int n = 0; n < invocationCount; n++) {
            // Advance the RNG exactly as one train call would; the split itself is discarded.
            rng.split();
        }
        trainInvocationCounter = invocationCount;
    }

    /**
     * Returns the trainer RNG, creating it from the seed if {@link #postConfig()} has not run.
     * Must be called while holding this instance's monitor.
     */
    private SplittableRandom sharedRng() {
        if (rng == null) {
            rng = new SplittableRandom(seed);
        }
        return rng;
    }

    /**
     * Builds the final model from the per-dimension inner models.
     *
     * @param models inner models keyed by the first name of each output dimension.
     * @param provenance the model provenance.
     * @param featureMap the feature id map of the training data.
     * @param outputInfo the output id info of the training data.
     * @return the assembled model.
     */
    protected abstract SkeletalIndependentRegressionModel createModel(Map<String, T> models, ModelProvenance provenance, ImmutableFeatureMap featureMap, ImmutableOutputInfo<Regressor> outputInfo);

    /**
     * Trains the inner model for a single output dimension.
     *
     * @param outputs target values for this dimension, one per example in dataset iteration order.
     * @param features feature vectors, one per example (built with the {@link #useBias()} flag).
     * @param weights per-example weights.
     * @param rng RNG split for the current train call; shared by all dimensions of that call.
     * @return the trained inner model.
     */
    protected abstract T trainDimension(double[] outputs, SparseVector[] features, float[] weights, SplittableRandom rng);

    /**
     * Returns the bias flag passed to {@code SparseVector.createSparseVector} when building features.
     *
     * @return true if a bias feature should be added.
     */
    protected abstract boolean useBias();

    /**
     * Returns the class name recorded in the model provenance.
     *
     * @return the model class name.
     */
    protected abstract String getModelClassName();
}

