/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.jcusolver.JCusolverDn;
import jcuda.jcusolver.cusolverDnHandle;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseHandle;
import jcuda.jcusparse.cusparseMatDescr;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysml.runtime.matrix.data.CudaSupportFunctions;
import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysml.runtime.matrix.data.LibMatrixNative;
import org.apache.sysml.utils.GPUStatistics;

/**
 * {@link CudaSupportFunctions} implementation that keeps device-side matrix data in single
 * precision ({@code float}).
 *
 * <p>Each cuBLAS / cuSPARSE / cuSOLVER wrapper forwards its arguments unchanged to the
 * {@code S}-prefixed (single-precision) JCuda binding and returns that binding's status code.
 * The host-facing transfer methods ({@link #deviceToHost} and {@link #hostToDevice}) take
 * {@code double[]} arrays, so they convert between {@code float} device data and
 * {@code double} host data.
 */
public class SinglePrecisionCudaSupportFunctions
implements CudaSupportFunctions {
    private static final Log LOG = LogFactory.getLog(SinglePrecisionCudaSupportFunctions.class.getName());

    /** Value of CUDA's {@code cudaMemcpyKind.cudaMemcpyHostToDevice}. */
    private static final int CUDA_MEMCPY_HOST_TO_DEVICE = 1;

    /** Value of CUDA's {@code cudaMemcpyKind.cudaMemcpyDeviceToHost}. */
    private static final int CUDA_MEMCPY_DEVICE_TO_HOST = 2;

    /** Delegates to {@code JCusparse.cusparseScsrgemm} and returns its status code. */
    @Override
    public int cusparsecsrgemm(cusparseHandle handle, int transA, int transB, int m, int n, int k, cusparseMatDescr descrA, int nnzA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, cusparseMatDescr descrB, int nnzB, Pointer csrValB, Pointer csrRowPtrB, Pointer csrColIndB, cusparseMatDescr descrC, Pointer csrValC, Pointer csrRowPtrC, Pointer csrColIndC) {
        return JCusparse.cusparseScsrgemm(handle, transA, transB, m, n, k,
                descrA, nnzA, csrValA, csrRowPtrA, csrColIndA,
                descrB, nnzB, csrValB, csrRowPtrB, csrColIndB,
                descrC, csrValC, csrRowPtrC, csrColIndC);
    }

    /** Delegates to {@code JCublas2.cublasSgeam} and returns its status code. */
    @Override
    public int cublasgeam(cublasHandle handle, int transa, int transb, int m, int n, Pointer alpha, Pointer A, int lda, Pointer beta, Pointer B, int ldb, Pointer C, int ldc) {
        return JCublas2.cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
    }

    /** Delegates to {@code JCusparse.cusparseScsrmv} and returns its status code. */
    @Override
    public int cusparsecsrmv(cusparseHandle handle, int transA, int m, int n, int nnz, Pointer alpha, cusparseMatDescr descrA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, Pointer x, Pointer beta, Pointer y) {
        return JCusparse.cusparseScsrmv(handle, transA, m, n, nnz, alpha,
                descrA, csrValA, csrRowPtrA, csrColIndA, x, beta, y);
    }

    /** Delegates to {@code JCusparse.cusparseScsrmm2} and returns its status code. */
    @Override
    public int cusparsecsrmm2(cusparseHandle handle, int transa, int transb, int m, int n, int k, int nnz, Pointer alpha, cusparseMatDescr descrA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, Pointer B, int ldb, Pointer beta, Pointer C, int ldc) {
        return JCusparse.cusparseScsrmm2(handle, transa, transb, m, n, k, nnz, alpha,
                descrA, csrValA, csrRowPtrA, csrColIndA, B, ldb, beta, C, ldc);
    }

    /** Delegates to {@code JCublas2.cublasSdot} and returns its status code. */
    @Override
    public int cublasdot(cublasHandle handle, int n, Pointer x, int incx, Pointer y, int incy, Pointer result) {
        return JCublas2.cublasSdot(handle, n, x, incx, y, incy, result);
    }

    /** Delegates to {@code JCublas2.cublasSgemv} and returns its status code. */
    @Override
    public int cublasgemv(cublasHandle handle, int trans, int m, int n, Pointer alpha, Pointer A, int lda, Pointer x, int incx, Pointer beta, Pointer y, int incy) {
        return JCublas2.cublasSgemv(handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy);
    }

    /** Delegates to {@code JCublas2.cublasSgemm} and returns its status code. */
    @Override
    public int cublasgemm(cublasHandle handle, int transa, int transb, int m, int n, int k, Pointer alpha, Pointer A, int lda, Pointer B, int ldb, Pointer beta, Pointer C, int ldc) {
        return JCublas2.cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
    }

    /** Delegates to {@code JCusparse.cusparseScsr2csc} and returns its status code. */
    @Override
    public int cusparsecsr2csc(cusparseHandle handle, int m, int n, int nnz, Pointer csrVal, Pointer csrRowPtr, Pointer csrColInd, Pointer cscVal, Pointer cscRowInd, Pointer cscColPtr, int copyValues, int idxBase) {
        return JCusparse.cusparseScsr2csc(handle, m, n, nnz, csrVal, csrRowPtr, csrColInd,
                cscVal, cscRowInd, cscColPtr, copyValues, idxBase);
    }

    /** Delegates to {@code JCublas2.cublasSsyrk} and returns its status code. */
    @Override
    public int cublassyrk(cublasHandle handle, int uplo, int trans, int n, int k, Pointer alpha, Pointer A, int lda, Pointer beta, Pointer C, int ldc) {
        return JCublas2.cublasSsyrk(handle, uplo, trans, n, k, alpha, A, lda, beta, C, ldc);
    }

    /** Delegates to {@code JCublas2.cublasSaxpy} and returns its status code. */
    @Override
    public int cublasaxpy(cublasHandle handle, int n, Pointer alpha, Pointer x, int incx, Pointer y, int incy) {
        return JCublas2.cublasSaxpy(handle, n, alpha, x, incx, y, incy);
    }

    /** Delegates to {@code JCublas2.cublasStrsm} and returns its status code. */
    @Override
    public int cublastrsm(cublasHandle handle, int side, int uplo, int trans, int diag, int m, int n, Pointer alpha, Pointer A, int lda, Pointer B, int ldb) {
        return JCublas2.cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb);
    }

    /**
     * Delegates to {@code JCusolverDn.cusolverDnSgeqrf_bufferSize} and returns its status code;
     * the binding writes the required workspace size into {@code Lwork}.
     */
    @Override
    public int cusolverDngeqrf_bufferSize(cusolverDnHandle handle, int m, int n, Pointer A, int lda, int[] Lwork) {
        return JCusolverDn.cusolverDnSgeqrf_bufferSize(handle, m, n, A, lda, Lwork);
    }

    /** Delegates to {@code JCusolverDn.cusolverDnSgeqrf} and returns its status code. */
    @Override
    public int cusolverDngeqrf(cusolverDnHandle handle, int m, int n, Pointer A, int lda, Pointer TAU, Pointer Workspace, int Lwork, Pointer devInfo) {
        return JCusolverDn.cusolverDnSgeqrf(handle, m, n, A, lda, TAU, Workspace, Lwork, devInfo);
    }

    /** Delegates to {@code JCusolverDn.cusolverDnSormqr} and returns its status code. */
    @Override
    public int cusolverDnormqr(cusolverDnHandle handle, int side, int trans, int m, int n, int k, Pointer A, int lda, Pointer tau, Pointer C, int ldc, Pointer work, int lwork, Pointer devInfo) {
        return JCusolverDn.cusolverDnSormqr(handle, side, trans, m, n, k, A, lda, tau, C, ldc,
                work, lwork, devInfo);
    }

    /** Delegates to {@code JCusparse.cusparseScsrgeam} and returns its status code. */
    @Override
    public int cusparsecsrgeam(cusparseHandle handle, int m, int n, Pointer alpha, cusparseMatDescr descrA, int nnzA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, Pointer beta, cusparseMatDescr descrB, int nnzB, Pointer csrValB, Pointer csrRowPtrB, Pointer csrColIndB, cusparseMatDescr descrC, Pointer csrValC, Pointer csrRowPtrC, Pointer csrColIndC) {
        return JCusparse.cusparseScsrgeam(handle, m, n,
                alpha, descrA, nnzA, csrValA, csrRowPtrA, csrColIndA,
                beta, descrB, nnzB, csrValB, csrRowPtrB, csrColIndB,
                descrC, csrValC, csrRowPtrC, csrColIndC);
    }

    /** Delegates to {@code JCusparse.cusparseScsr2dense} and returns its status code. */
    @Override
    public int cusparsecsr2dense(cusparseHandle handle, int m, int n, cusparseMatDescr descrA, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA, Pointer A, int lda) {
        return JCusparse.cusparseScsr2dense(handle, m, n, descrA, csrValA, csrRowPtrA, csrColIndA, A, lda);
    }

    /** Delegates to {@code JCusparse.cusparseSdense2csr} and returns its status code. */
    @Override
    public int cusparsedense2csr(cusparseHandle handle, int m, int n, cusparseMatDescr descrA, Pointer A, int lda, Pointer nnzPerRow, Pointer csrValA, Pointer csrRowPtrA, Pointer csrColIndA) {
        return JCusparse.cusparseSdense2csr(handle, m, n, descrA, A, lda, nnzPerRow,
                csrValA, csrRowPtrA, csrColIndA);
    }

    /** Delegates to {@code JCusparse.cusparseSnnz} and returns its status code. */
    @Override
    public int cusparsennz(cusparseHandle handle, int dirA, int m, int n, cusparseMatDescr descrA, Pointer A, int lda, Pointer nnzPerRowCol, Pointer nnzTotalDevHostPtr) {
        return JCusparse.cusparseSnnz(handle, dirA, m, n, descrA, A, lda, nnzPerRowCol, nnzTotalDevHostPtr);
    }

    /**
     * Copies {@code dest.length} single-precision values from the device buffer {@code src}
     * into the host array {@code dest}, widening each value to {@code double}.
     *
     * <p>Outside of eviction the widening is done on the device. A temporary double-precision
     * device buffer is allocated and filled by {@link LibMatrixCUDA#float2double}. It is then
     * copied to the host and released, even if the conversion or copy fails.
     *
     * <p>During eviction the raw floats are instead copied into a direct host buffer and widened
     * on the host, so no extra device memory is requested. NOTE(review): presumably this is
     * because eviction happens under device-memory pressure; confirm with the caller.
     *
     * <p>When {@code DMLScript.STATISTICS} is set, the elapsed time is recorded in the
     * float-to-double conversion statistics.
     *
     * @param gCtx       GPU context used for the temporary allocation and the conversion kernel
     * @param src        device pointer holding at least {@code dest.length} floats
     * @param dest       host array that receives the converted values
     * @param instName   instruction name for allocation bookkeeping and fine-grained statistics;
     *                   may be {@code null}
     * @param isEviction whether this copy is part of evicting data from the device
     * @throws ArithmeticException on the eviction path, if {@code dest.length} floats do not fit
     *                             in a single direct buffer (more than {@code Integer.MAX_VALUE}
     *                             bytes)
     */
    @Override
    public void deviceToHost(GPUContext gCtx, Pointer src, double[] dest, String instName, boolean isEviction) {
        final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (isEviction) {
            LOG.debug("Potential OOM: Allocated additional space on host in deviceToHost");
            // allocateDirect takes an int capacity. multiplyExact stops a silent wrap-around,
            // which would give an undersized buffer that cudaMemcpy then overruns.
            final int floatBytes = Math.multiplyExact(dest.length, Float.BYTES);
            FloatBuffer floatData = ByteBuffer.allocateDirect(floatBytes)
                    .order(ByteOrder.nativeOrder())
                    .asFloatBuffer();
            JCuda.cudaMemcpy(Pointer.to(floatData), src, (long) floatBytes, CUDA_MEMCPY_DEVICE_TO_HOST);
            LibMatrixNative.fromFloatBuffer(floatData, dest);
        } else {
            final long doubleBytes = (long) dest.length * Double.BYTES;
            Pointer deviceDoubleData = gCtx.allocate(instName, doubleBytes);
            try {
                LibMatrixCUDA.float2double(gCtx, src, deviceDoubleData, dest.length);
                JCuda.cudaMemcpy(Pointer.to(dest), deviceDoubleData, doubleBytes, CUDA_MEMCPY_DEVICE_TO_HOST);
            } finally {
                // Always release the temporary device buffer, including on failure.
                gCtx.cudaFreeHelper(instName, deviceDoubleData, DMLScript.EAGER_CUDA_FREE);
            }
        }
        if (DMLScript.STATISTICS) {
            long totalTime = System.nanoTime() - t0;
            GPUStatistics.cudaFloat2DoubleTime.add(totalTime);
            GPUStatistics.cudaFloat2DoubleCount.add(1L);
            if (DMLScript.FINEGRAINED_STATISTICS && instName != null) {
                GPUStatistics.maintainCPMiscTimes(instName, "D2H", totalTime);
            }
        }
    }

    /**
     * Copies the host array {@code src} to the device and narrows it into the single-precision
     * device buffer {@code dest}.
     *
     * <p>The doubles are first copied into a temporary double-precision device buffer. They are
     * then narrowed on the device by {@link LibMatrixCUDA#double2float}. The temporary buffer is
     * released even if the copy or conversion fails.
     *
     * <p>When {@code DMLScript.STATISTICS} is set, the elapsed time is recorded in the
     * double-to-float conversion statistics.
     *
     * @param gCtx     GPU context used for the temporary allocation and the conversion kernel
     * @param src      host values to upload
     * @param dest     device pointer that receives {@code src.length} floats
     * @param instName instruction name for allocation bookkeeping and fine-grained statistics;
     *                 may be {@code null}
     */
    @Override
    public void hostToDevice(GPUContext gCtx, double[] src, Pointer dest, String instName) {
        LOG.debug("Potential OOM: Allocated additional space in hostToDevice");
        final long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        final long doubleBytes = (long) src.length * Double.BYTES;
        Pointer deviceDoubleData = gCtx.allocate(instName, doubleBytes);
        try {
            JCuda.cudaMemcpy(deviceDoubleData, Pointer.to(src), doubleBytes, CUDA_MEMCPY_HOST_TO_DEVICE);
            LibMatrixCUDA.double2float(gCtx, deviceDoubleData, dest, src.length);
        } finally {
            // Always release the temporary device buffer, including on failure.
            gCtx.cudaFreeHelper(instName, deviceDoubleData, DMLScript.EAGER_CUDA_FREE);
        }
        if (DMLScript.STATISTICS) {
            long totalTime = System.nanoTime() - t0;
            GPUStatistics.cudaDouble2FloatTime.add(totalTime);
            GPUStatistics.cudaDouble2FloatCount.add(1L);
            if (DMLScript.FINEGRAINED_STATISTICS && instName != null) {
                GPUStatistics.maintainCPMiscTimes(instName, "H2D", totalTime);
            }
        }
    }
}

