Add a patch for mfem v4.3 to support cusparse >= 11.4 (#27267)

This commit is contained in:
Veselin Dobrev 2021-11-09 18:29:06 -08:00 committed by GitHub
parent 723f2f465b
commit ceabb96c89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 82 additions and 0 deletions

View File

@ -0,0 +1,80 @@
diff --git a/linalg/sparsemat.cpp b/linalg/sparsemat.cpp
index 12136e035..0be73cf7b 100644
--- a/linalg/sparsemat.cpp
+++ b/linalg/sparsemat.cpp
@@ -33,7 +33,12 @@ int SparseMatrix::SparseMatrixCount = 0;
cusparseHandle_t SparseMatrix::handle = nullptr;
size_t SparseMatrix::bufferSize = 0;
void * SparseMatrix::dBuffer = nullptr;
-#endif
+# if CUSPARSE_VERSION >= 11400
+# define MFEM_CUSPARSE_ALG CUSPARSE_SPMV_CSR_ALG1
+# else
+# define MFEM_CUSPARSE_ALG CUSPARSE_CSRMV_ALG1
+# endif // CUSPARSE_VERSION >= 11400
+#endif // MFEM_USE_CUDA
void SparseMatrix::InitCuSparse()
{
@@ -679,25 +684,16 @@ void SparseMatrix::AddMult(const Vector &x, Vector &y, const double a) const
cusparseCreateMatDescr(&matA_descr);
cusparseSetMatIndexBase(matA_descr, CUSPARSE_INDEX_BASE_ZERO);
cusparseSetMatType(matA_descr, CUSPARSE_MATRIX_TYPE_GENERAL);
-
#endif
-
initBuffers = true;
}
// Allocate kernel space. Buffer is shared between different sparsemats
size_t newBufferSize = 0;
-#if CUDA_VERSION >= 11020
- cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
- matA_descr,
- vecX_descr, &beta, vecY_descr, CUDA_R_64F,
- CUSPARSE_SPMV_CSR_ALG1, &newBufferSize);
-#elif CUDA_VERSION >= 10010
cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
matA_descr,
vecX_descr, &beta, vecY_descr, CUDA_R_64F,
- CUSPARSE_CSRMV_ALG1, &newBufferSize);
-#endif
+ MFEM_CUSPARSE_ALG, &newBufferSize);
// Check if we need to resize
if (newBufferSize > bufferSize)
@@ -707,30 +703,22 @@ void SparseMatrix::AddMult(const Vector &x, Vector &y, const double a) const
CuMemAlloc(&dBuffer, bufferSize);
}
-#if CUDA_VERSION >= 11020
- // Update input/output vectors
- cusparseDnVecSetValues(vecX_descr, const_cast<double *>(d_x));
- cusparseDnVecSetValues(vecY_descr, d_y);
-
- // Y = alpha A * X + beta * Y
- cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA_descr,
- vecX_descr, &beta, vecY_descr, CUDA_R_64F, CUSPARSE_SPMV_CSR_ALG1, dBuffer);
-#elif CUDA_VERSION >= 10010
+#if CUDA_VERSION >= 10010
// Update input/output vectors
cusparseDnVecSetValues(vecX_descr, const_cast<double *>(d_x));
cusparseDnVecSetValues(vecY_descr, d_y);
// Y = alpha A * X + beta * Y
cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA_descr,
- vecX_descr, &beta, vecY_descr, CUDA_R_64F, CUSPARSE_CSRMV_ALG1, dBuffer);
+ vecX_descr, &beta, vecY_descr, CUDA_R_64F, MFEM_CUSPARSE_ALG, dBuffer);
#else
cusparseDcsrmv(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
Height(), Width(), J.Capacity(),
&alpha, matA_descr,
const_cast<double *>(d_A), const_cast<int *>(d_I), const_cast<int *>(d_J),
const_cast<double *>(d_x), &beta, d_y);
-#endif
-#endif
+#endif // CUDA_VERSION >= 10010
+#endif // MFEM_USE_CUDA
}
else
{

View File

@ -189,6 +189,7 @@ class Mfem(Package, CudaPackage, ROCmPackage):
conflicts('+umpire', when='mfem@:4.0') conflicts('+umpire', when='mfem@:4.0')
conflicts('+amgx', when='mfem@:4.1') conflicts('+amgx', when='mfem@:4.1')
conflicts('+amgx', when='~cuda') conflicts('+amgx', when='~cuda')
conflicts('+mpi~cuda ^hypre+cuda')
conflicts('+superlu-dist', when='~mpi') conflicts('+superlu-dist', when='~mpi')
conflicts('+strumpack', when='~mpi') conflicts('+strumpack', when='~mpi')
@ -301,6 +302,7 @@ class Mfem(Package, CudaPackage, ROCmPackage):
patch('mfem-4.2-slepc.patch', when='@4.2.0+slepc') patch('mfem-4.2-slepc.patch', when='@4.2.0+slepc')
patch('mfem-4.2-petsc-3.15.0.patch', when='@4.2.0+petsc ^petsc@3.15.0:') patch('mfem-4.2-petsc-3.15.0.patch', when='@4.2.0+petsc ^petsc@3.15.0:')
patch('mfem-4.3-hypre-2.23.0.patch', when='@4.3.0') patch('mfem-4.3-hypre-2.23.0.patch', when='@4.3.0')
patch('mfem-4.3-cusparse-11.4.patch', when='@4.3.0+cuda')
# Patch to fix MFEM makefile syntax error. See # Patch to fix MFEM makefile syntax error. See
# https://github.com/mfem/mfem/issues/1042 for the bug report and # https://github.com/mfem/mfem/issues/1042 for the bug report and