Add a patch for mfem v4.3 to support cusparse >= 11.4 (#27267)
This commit is contained in:
parent
723f2f465b
commit
ceabb96c89
@ -0,0 +1,80 @@
|
||||
diff --git a/linalg/sparsemat.cpp b/linalg/sparsemat.cpp
|
||||
index 12136e035..0be73cf7b 100644
|
||||
--- a/linalg/sparsemat.cpp
|
||||
+++ b/linalg/sparsemat.cpp
|
||||
@@ -33,7 +33,12 @@ int SparseMatrix::SparseMatrixCount = 0;
|
||||
cusparseHandle_t SparseMatrix::handle = nullptr;
|
||||
size_t SparseMatrix::bufferSize = 0;
|
||||
void * SparseMatrix::dBuffer = nullptr;
|
||||
-#endif
|
||||
+# if CUSPARSE_VERSION >= 11400
|
||||
+# define MFEM_CUSPARSE_ALG CUSPARSE_SPMV_CSR_ALG1
|
||||
+# else
|
||||
+# define MFEM_CUSPARSE_ALG CUSPARSE_CSRMV_ALG1
|
||||
+# endif // CUSPARSE_VERSION >= 11400
|
||||
+#endif // MFEM_USE_CUDA
|
||||
|
||||
void SparseMatrix::InitCuSparse()
|
||||
{
|
||||
@@ -679,25 +684,16 @@ void SparseMatrix::AddMult(const Vector &x, Vector &y, const double a) const
|
||||
cusparseCreateMatDescr(&matA_descr);
|
||||
cusparseSetMatIndexBase(matA_descr, CUSPARSE_INDEX_BASE_ZERO);
|
||||
cusparseSetMatType(matA_descr, CUSPARSE_MATRIX_TYPE_GENERAL);
|
||||
-
|
||||
#endif
|
||||
-
|
||||
initBuffers = true;
|
||||
}
|
||||
// Allocate kernel space. Buffer is shared between different sparsemats
|
||||
size_t newBufferSize = 0;
|
||||
|
||||
-#if CUDA_VERSION >= 11020
|
||||
- cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
- matA_descr,
|
||||
- vecX_descr, &beta, vecY_descr, CUDA_R_64F,
|
||||
- CUSPARSE_SPMV_CSR_ALG1, &newBufferSize);
|
||||
-#elif CUDA_VERSION >= 10010
|
||||
cusparseSpMV_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
|
||||
matA_descr,
|
||||
vecX_descr, &beta, vecY_descr, CUDA_R_64F,
|
||||
- CUSPARSE_CSRMV_ALG1, &newBufferSize);
|
||||
-#endif
|
||||
+ MFEM_CUSPARSE_ALG, &newBufferSize);
|
||||
|
||||
// Check if we need to resize
|
||||
if (newBufferSize > bufferSize)
|
||||
@@ -707,30 +703,22 @@ void SparseMatrix::AddMult(const Vector &x, Vector &y, const double a) const
|
||||
CuMemAlloc(&dBuffer, bufferSize);
|
||||
}
|
||||
|
||||
-#if CUDA_VERSION >= 11020
|
||||
- // Update input/output vectors
|
||||
- cusparseDnVecSetValues(vecX_descr, const_cast<double *>(d_x));
|
||||
- cusparseDnVecSetValues(vecY_descr, d_y);
|
||||
-
|
||||
- // Y = alpha A * X + beta * Y
|
||||
- cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA_descr,
|
||||
- vecX_descr, &beta, vecY_descr, CUDA_R_64F, CUSPARSE_SPMV_CSR_ALG1, dBuffer);
|
||||
-#elif CUDA_VERSION >= 10010
|
||||
+#if CUDA_VERSION >= 10010
|
||||
// Update input/output vectors
|
||||
cusparseDnVecSetValues(vecX_descr, const_cast<double *>(d_x));
|
||||
cusparseDnVecSetValues(vecY_descr, d_y);
|
||||
|
||||
// Y = alpha A * X + beta * Y
|
||||
cusparseSpMV(handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, matA_descr,
|
||||
- vecX_descr, &beta, vecY_descr, CUDA_R_64F, CUSPARSE_CSRMV_ALG1, dBuffer);
|
||||
+ vecX_descr, &beta, vecY_descr, CUDA_R_64F, MFEM_CUSPARSE_ALG, dBuffer);
|
||||
#else
|
||||
cusparseDcsrmv(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
Height(), Width(), J.Capacity(),
|
||||
&alpha, matA_descr,
|
||||
const_cast<double *>(d_A), const_cast<int *>(d_I), const_cast<int *>(d_J),
|
||||
const_cast<double *>(d_x), &beta, d_y);
|
||||
-#endif
|
||||
-#endif
|
||||
+#endif // CUDA_VERSION >= 10010
|
||||
+#endif // MFEM_USE_CUDA
|
||||
}
|
||||
else
|
||||
{
|
@ -189,6 +189,7 @@ class Mfem(Package, CudaPackage, ROCmPackage):
|
||||
conflicts('+umpire', when='mfem@:4.0')
|
||||
conflicts('+amgx', when='mfem@:4.1')
|
||||
conflicts('+amgx', when='~cuda')
|
||||
conflicts('+mpi~cuda ^hypre+cuda')
|
||||
|
||||
conflicts('+superlu-dist', when='~mpi')
|
||||
conflicts('+strumpack', when='~mpi')
|
||||
@ -301,6 +302,7 @@ class Mfem(Package, CudaPackage, ROCmPackage):
|
||||
patch('mfem-4.2-slepc.patch', when='@4.2.0+slepc')
|
||||
patch('mfem-4.2-petsc-3.15.0.patch', when='@4.2.0+petsc ^petsc@3.15.0:')
|
||||
patch('mfem-4.3-hypre-2.23.0.patch', when='@4.3.0')
|
||||
patch('mfem-4.3-cusparse-11.4.patch', when='@4.3.0+cuda')
|
||||
|
||||
# Patch to fix MFEM makefile syntax error. See
|
||||
# https://github.com/mfem/mfem/issues/1042 for the bug report and
|
||||
|
Loading…
Reference in New Issue
Block a user