petsc: improve hipsparse compat (#40311)

Co-authored-by: Satish Balay <balay@mcs.anl.gov>
This commit is contained in:
Harmen Stoppels 2023-12-11 10:30:14 +01:00 committed by GitHub
parent a6c32c80ab
commit 525809632e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 127 additions and 15 deletions

View File

@ -0,0 +1,85 @@
commit 9b52b1224039b470f0f450943ce503af1df37b00
Author: Satish Balay <balay@mcs.anl.gov>
Date: Fri Oct 6 15:19:34 2023 -0500
hip-6.0 fix
diff --git a/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp b/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp
index e6be2076975..0c388c90ca3 100644
--- a/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp
+++ b/src/mat/impls/aij/seq/seqhipsparse/aijhipsparse.hip.cpp
@@ -1259,14 +1259,22 @@ static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Vec x)
/* Solve L*y = b */
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
+ #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
+ PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
+ fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
+ #else
PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L)); // hipsparseSpSV_solve() secretely uses the external buffer used in hipsparseSpSV_analysis()!
-
+ #endif
/* Solve U*x = y */
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
+ #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
+ fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U));
+ #else
+ PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* U X = Y */
fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_U, fs->spsvBuffer_U));
-
+ #endif
PetscCall(VecHIPRestoreArrayRead(b, &barray));
PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
@@ -1309,14 +1317,22 @@ static PetscErrorCode MatSolveTranspose_SeqAIJHIPSPARSE_ILU0(Mat fact, Vec b, Ve
/* Solve Ut*y = b */
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
+ #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
+ PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
+ fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut));
+ #else
PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_U, /* Ut Y = X */
fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Ut, fs->spsvBuffer_Ut));
-
+ #endif
/* Solve Lt*x = y */
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
+ #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
+ PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
+ fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
+ #else
PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
-
+ #endif
PetscCall(VecHIPRestoreArrayRead(b, &barray));
PetscCall(VecHIPRestoreArrayWrite(x, &xarray));
PetscCall(PetscLogGpuTimeEnd());
@@ -1544,14 +1560,22 @@ static PetscErrorCode MatSolve_SeqAIJHIPSPARSE_ICC0(Mat fact, Vec b, Vec x)
/* Solve L*y = b */
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, (void *)barray));
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_Y, fs->Y));
+ #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
+ PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
+ fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L));
+ #else
PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_NON_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* L Y = X */
fs->dnVecDescr_X, fs->dnVecDescr_Y, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_L, fs->spsvBuffer_L));
-
+ #endif
/* Solve Lt*x = y */
PetscCallHIPSPARSE(hipsparseDnVecSetValues(fs->dnVecDescr_X, xarray));
+ #if PETSC_PKG_HIP_VERSION_EQ(5, 6, 0)
+ PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
+ fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt));
+ #else
PetscCallHIPSPARSE(hipsparseSpSV_solve(fs->handle, HIPSPARSE_OPERATION_TRANSPOSE, &PETSC_HIPSPARSE_ONE, fs->spMatDescr_L, /* Lt X = Y */
fs->dnVecDescr_Y, fs->dnVecDescr_X, hipsparse_scalartype, HIPSPARSE_SPSV_ALG_DEFAULT, fs->spsvDescr_Lt, fs->spsvBuffer_Lt));
-
+ #endif
PetscCall(VecHIPRestoreArrayRead(b, &barray));
PetscCall(VecHIPRestoreArrayWrite(x, &xarray));

View File

@ -0,0 +1,20 @@
diff --git a/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp b/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp
index a39933c6893..6ef9f513bd6 100644
--- a/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp
+++ b/src/vec/is/sf/impls/basic/hip/sfhip.hip.cpp
@@ -471,6 +471,7 @@ __device__ static float atomicMax(float *address, float val)
#endif
/* As of ROCm 3.10 llint atomicMin/Max(llint*, llint) is not supported */
+#if PETSC_PKG_HIP_VERSION_LT(5, 7, 0)
__device__ static llint atomicMin(llint *address, llint val)
{
ullint *address_as_ull = (ullint *)(address);
@@ -492,6 +493,7 @@ __device__ static llint atomicMax(llint *address, llint val)
} while (assumed != old);
return (llint)old;
}
+#endif
template <typename Type>
struct AtomicMin {

View File

@ -161,12 +161,17 @@ class Petsc(Package, CudaPackage, ROCmPackage):
variant("kokkos", default=False, description="Activates support for kokkos and kokkos-kernels")
variant("fortran", default=True, description="Activates fortran support")
# https://github.com/spack/spack/issues/37416
conflicts("^rocprim@5.3.0:5.3.2", when="+rocm")
# petsc 3.20 has workaround for breaking change in hipsparseSpSV_solve api,
# but it seems to misdetect hipsparse@5.6.1 as 5.6.0, so the workaround
# only makes things worse
conflicts("^hipsparse@5.6", when="+rocm @3.20.0")
with when("+rocm"):
# https://github.com/spack/spack/issues/37416
conflicts("^rocprim@5.3.0:5.3.2")
# hipsparse@5.6.0 broke hipsparseSpSV_solve() API, reverted in 5.6.1.
patch(
"https://gitlab.com/petsc/petsc/-/commit/ef7140cce45367033b48bbd2624dfd2b6aa4b997.diff",
when="@3.20.0",
sha256="ba327f8b2a0fa45209dfb7a4278f3e9a323965b5a668be204c1c77c17a963a7f",
)
patch("hip-5.6.0-for-3.18.diff", when="@3.18:3.19 ^hipsparse@5.6.0")
patch("hip-5.7-plus-for-3.18.diff", when="@3.18:3.19 ^hipsparse@5.7:")
# 3.8.0 has a build issue with MKL - so list this conflict explicitly
conflicts("^intel-mkl", when="@3.8.0")
@ -225,15 +230,17 @@ def check_fortran_compiler(self):
depends_on("mpi", when="+mpi")
depends_on("cuda", when="+cuda")
depends_on("hip", when="+rocm")
depends_on("hipblas", when="+rocm")
depends_on("hipsparse", when="+rocm")
depends_on("hipsolver", when="+rocm")
depends_on("rocsparse", when="+rocm")
depends_on("rocsolver", when="+rocm")
depends_on("rocblas", when="+rocm")
depends_on("rocrand", when="+rocm")
depends_on("rocthrust", when="+rocm")
depends_on("rocprim", when="+rocm")
with when("+rocm"):
depends_on("hipblas")
depends_on("hipsparse")
depends_on("hipsolver")
depends_on("rocsparse")
depends_on("rocsolver")
depends_on("rocblas")
depends_on("rocrand")
depends_on("rocthrust")
depends_on("rocprim")
# Build dependencies
depends_on("python@2.6:2.8,3.4:3.8", when="@:3.13", type="build")