/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;

/**
 * Column encoder that forwards the numeric value of its column to the output
 * without transforming it.
 *
 * <p>The encoder keeps no state of its own: {@link #build(CacheBlock)} and the
 * metadata hooks are no-ops. Values are read with
 * {@code in.getDoubleNaN(row, _colID - 1)}; the {@code - 1} suggests
 * {@code _colID} is 1-based (NOTE(review): confirm against {@code ColumnEncoder}).
 */
public class ColumnEncoderPassThrough
extends ColumnEncoder {
    private static final long serialVersionUID = -8473768154646831882L;

    /** Number of rows read and written per iteration of the sparse apply loop. */
    private static final int SPARSE_APPLY_TILE_ROWS = 32;

    /**
     * Creates a pass-through encoder for the given column.
     *
     * @param ptCols column id handed to the {@code ColumnEncoder} constructor
     */
    protected ColumnEncoderPassThrough(int ptCols) {
        super(ptCols);
    }

    /** Creates an encoder with column id {@code -1}. */
    public ColumnEncoderPassThrough() {
        this(-1);
    }

    /** Does nothing; pass-through has nothing to build. */
    @Override
    public void build(CacheBlock<?> in) {
    }

    /**
     * @return always {@code null}; presumably callers treat this as "no build
     *         tasks" (NOTE(review): verify against the callers)
     */
    @Override
    public List<DependencyTask<?>> getBuildTasks(CacheBlock<?> in) {
        return null;
    }

    /** Creates the task that runs the sparse apply of this encoder on a row range. */
    @Override
    protected ColumnEncoder.ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock<?> in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new PassThroughSparseApplyTask(this, in, out, outputCol, startRow, blk);
    }

    /**
     * Reads the value of this encoder's column in one row.
     *
     * @param in  input block
     * @param row row index
     * @return {@code in.getDoubleNaN(row, _colID - 1)}
     */
    @Override
    protected double getCode(CacheBlock<?> in, int row) {
        return in.getDoubleNaN(row, _colID - 1);
    }

    /**
     * Reads the values of this encoder's column for rows {@code [startInd, endInd)}.
     *
     * @param in       input block
     * @param startInd first row (inclusive)
     * @param endInd   last row (exclusive)
     * @param tmp      optional buffer; reused only if its length equals
     *                 {@code endInd - startInd}, otherwise a new array is allocated
     * @return array of length {@code endInd - startInd} holding the values
     * @throws RuntimeException wrapping any exception raised while reading
     */
    @Override
    protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) {
        try {
            final int len = endInd - startInd;
            final double[] codes = (tmp != null && tmp.length == len) ? tmp : new double[len];
            for (int r = startInd; r < endInd; ++r) {
                codes[r - startInd] = in.getDoubleNaN(r, _colID - 1);
            }
            return codes;
        }
        catch (Exception e) {
            throw new RuntimeException("Failed to get code for col: " + _colID + " on range: " + startInd + "->" + endInd, e);
        }
    }

    /**
     * Writes this column's values into the sparse block of {@code out}.
     *
     * <p>Rows are processed in tiles of {@link #SPARSE_APPLY_TILE_ROWS}: each tile
     * is first read into a buffer via {@link #getCodeCol} and then written into
     * the MCSR or CSR representation. Any sparse block that is not a
     * {@code SparseBlockMCSR} is cast to {@code SparseBlockCSR}.
     *
     * <p>NOTE(review): the writers address slot {@code _colID - 1} inside each
     * output row, so the output rows are assumed to be pre-allocated with one
     * slot per encoder; confirm against the code that allocates {@code out}.
     *
     * @param in        input block
     * @param out       output matrix; its sparse block is written in place
     * @param outputCol column index stored for the written entry
     * @param rowStart  first row to process
     * @param blk       block size passed to {@code UtilFunctions.getEndIndex}
     */
    @Override
    protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        final SparseBlock sb = out.getSparseBlock();
        final boolean mcsr = sb instanceof SparseBlockMCSR;
        final int index = _colID - 1;
        final int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        double[] tmp = null;
        for (int i = rowStart; i < rowEnd; i += SPARSE_APPLY_TILE_ROWS) {
            final int end = Math.min(i + SPARSE_APPLY_TILE_ROWS, rowEnd);
            // The buffer is reused across full tiles; a shorter last tile gets a new array.
            tmp = getCodeCol(in, i, end, tmp);
            if (mcsr) {
                applySparseBlockMCSR((SparseBlockMCSR)sb, index, outputCol, i, end, tmp);
            }
            else {
                applySparseBlockCSR((SparseBlockCSR)sb, index, outputCol, i, end, tmp);
            }
        }
    }

    /**
     * Writes one tile into an MCSR block: for every row in {@code [rl, ru)} the
     * column index at slot {@code index} is set to {@code outputCol}; the value
     * is stored only when it is not equal to zero.
     */
    private void applySparseBlockMCSR(SparseBlockMCSR sb, int index, int outputCol, int rl, int ru, double[] tmpCodes) {
        for (int i = rl; i < ru; ++i) {
            final double v = tmpCodes[i - rl];
            final SparseRowVector row = (SparseRowVector)sb.get(i);
            row.indexes()[index] = outputCol;
            if (v == 0.0) {
                // Value slot is left untouched; only record that a zero was seen.
                containsZeroOut = true;
                continue;
            }
            row.values()[index] = v;
        }
    }

    /**
     * Writes one tile into a CSR block: for every row in {@code [rl, ru)} the
     * entry at offset {@code rowPointers()[row] + index} gets column index
     * {@code outputCol}; the value is stored only when it is not equal to zero.
     */
    private void applySparseBlockCSR(SparseBlockCSR sb, int index, int outputCol, int rl, int ru, double[] tmpCodes) {
        final int[] rptr = sb.rowPointers();
        final int[] idx = sb.indexes();
        final double[] val = sb.values();
        for (int i = rl; i < ru; ++i) {
            final double v = tmpCodes[i - rl];
            final int pos = rptr[i] + index;
            idx[pos] = outputCol;
            if (v == 0.0) {
                // Value slot is left untouched; only record that a zero was seen.
                containsZeroOut = true;
                continue;
            }
            val[pos] = v;
        }
    }

    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.PASS_THROUGH;
    }

    /**
     * Merging another pass-through encoder is a no-op; any other encoder type is
     * delegated to {@code super.mergeAt(other)}.
     */
    @Override
    public void mergeAt(ColumnEncoder other) {
        if (other instanceof ColumnEncoderPassThrough) {
            return;
        }
        super.mergeAt(other);
    }

    /** Does nothing; this encoder allocates no metadata. */
    @Override
    public void allocateMetaData(FrameBlock meta) {
    }

    /**
     * @param meta metadata frame
     * @return {@code meta}, unchanged
     */
    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        return meta;
    }

    /** Does nothing; this encoder reads no metadata. */
    @Override
    public void initMetaData(FrameBlock meta) {
    }

    /** @return the simple class name followed by {@code ": "} and the column id */
    @Override
    public String toString() {
        return getClass().getSimpleName() + ": " + _colID;
    }

    /**
     * Task that applies a {@link ColumnEncoderPassThrough} to a row range of a
     * sparse output block.
     */
    public static class PassThroughSparseApplyTask
    extends ColumnEncoder.ColumnApplyTask<ColumnEncoderPassThrough> {
        protected PassThroughSparseApplyTask(ColumnEncoderPassThrough encoder, CacheBlock<?> input, MatrixBlock out, int outputCol, int startRow, int blk) {
            super(encoder, input, out, outputCol, startRow, blk);
        }

        /**
         * Runs the sparse apply. Returns immediately when the output has no
         * sparse block; when {@code DMLScript.STATISTICS} is set, the elapsed
         * time is added to the pass-through apply statistics.
         *
         * @return always {@code null}
         */
        @Override
        public Object call() throws Exception {
            if (this._out.getSparseBlock() == null) {
                return null;
            }
            final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            ((ColumnEncoderPassThrough)this._encoder).applySparse(this._input, this._out, this._outputCol, this._startRow, this._blk);
            if (DMLScript.STATISTICS) {
                TransformStatistics.incPassThroughApplyTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<ColId: " + ((ColumnEncoderPassThrough)this._encoder)._colID + ">";
        }
    }
}

