diff --git a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml index f3bdf578fa1..a285c131f7c 100644 --- a/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml +++ b/share/spack/gitlab/cloud_pipelines/stacks/ml-linux-x86_64-rocm/spack.yaml @@ -27,9 +27,8 @@ spack: - py-transformers # JAX - # Does not yet support Spack-installed ROCm - # - py-jax - # - py-jaxlib + - py-jax + - py-jaxlib # Keras - py-keras backend=tensorflow diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py index f99d02a83a2..a09afc11792 100644 --- a/var/spack/repos/builtin/packages/py-jax/package.py +++ b/var/spack/repos/builtin/packages/py-jax/package.py @@ -20,13 +20,13 @@ class PyJax(PythonPackage): maintainers("adamjstewart", "jonas-eschle") # version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8") - # version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8") - # version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b") - # version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a") - # version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e") - # version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db") - # version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92") - # version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08") + version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8") + version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b") + version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a") + version("0.4.35", 
sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e") + version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db") + version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92") + version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08") version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287") version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577") version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186") @@ -85,13 +85,13 @@ class PyJax(PythonPackage): # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 for v in [ # "0.5.0", - # "0.4.38", - # "0.4.37", - # "0.4.36", - # "0.4.35", - # "0.4.34", - # "0.4.33", - # "0.4.32", + "0.4.38", + "0.4.37", + "0.4.36", + "0.4.35", + "0.4.34", + "0.4.33", + "0.4.32", "0.4.31", "0.4.30", "0.4.29", @@ -126,12 +126,12 @@ class PyJax(PythonPackage): # See _minimum_jaxlib_version in jax/version.py # depends_on("py-jaxlib@0.5:", when="@0.5:") - # depends_on("py-jaxlib@0.4.38:", when="@0.4.38:") - # depends_on("py-jaxlib@0.4.36:", when="@0.4.36:") - # depends_on("py-jaxlib@0.4.35:", when="@0.4.35:") - # depends_on("py-jaxlib@0.4.34:", when="@0.4.34:") - # depends_on("py-jaxlib@0.4.33:", when="@0.4.33:") - # depends_on("py-jaxlib@0.4.32:", when="@0.4.32:") + depends_on("py-jaxlib@0.4.38:", when="@0.4.38:") + depends_on("py-jaxlib@0.4.36:", when="@0.4.36:") + depends_on("py-jaxlib@0.4.35:", when="@0.4.35:") + depends_on("py-jaxlib@0.4.34:", when="@0.4.34:") + depends_on("py-jaxlib@0.4.33:", when="@0.4.33:") + depends_on("py-jaxlib@0.4.32:", when="@0.4.32:") depends_on("py-jaxlib@0.4.30:", when="@0.4.31:") depends_on("py-jaxlib@0.4.27:", when="@0.4.28:") depends_on("py-jaxlib@0.4.23:", when="@0.4.27:") diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py 
b/var/spack/repos/builtin/packages/py-jaxlib/package.py index cd679311a25..a43ce8a1dec 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -8,20 +8,27 @@ from spack.package import * rocm_dependencies = [ - "hsa-rocr-dev", + "comgr", "hip", - "rccl", - "rocprim", + "hipblas", + "hipblaslt", "hipcub", - "rocthrust", - "roctracer-dev", - "rocrand", - "hipsparse", "hipfft", - "rocfft", - "rocblas", + "hiprand", + "hipsolver", + "hipsparse", + "hsa-rocr-dev", "miopen-hip", + "rccl", + "rocblas", + "rocfft", "rocminfo", + "rocprim", + "rocrand", + "rocsolver", + "rocsparse", + "roctracer-dev", + "rocm-core", ] @@ -39,14 +46,17 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): license("Apache-2.0") maintainers("adamjstewart", "jonas-eschle") + # version("0.5.3", sha256="1094581a30ec069965f4e3e67d60262570cc3dd016adc62073bc24347b14270c") + # version("0.5.2", sha256="8e9de1e012dd65fc4a9eec8af4aa2bf6782767130a5d8e1c1e342b7d658280fe") + # version("0.5.1", sha256="e74b1209517682075933f757d646b73040d09fe39ee3e9e4cd398407dd0902d2") # version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9") - # version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c") - # version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870") - # version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77") - # version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212") - # version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb") - # version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31") - # version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2") + version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c") + 
version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870") + version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77") + version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212") + version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb") + version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31") + version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2") version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94") version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6") version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246") @@ -93,6 +103,10 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): for pkg_dep in rocm_dependencies: depends_on(f"{pkg_dep}@6:", when="@0.4.28:") depends_on(pkg_dep) + depends_on("rocprofiler-register", when="^hip@6.2:") + depends_on("hipblas-common", when="^hip@6.3:") + depends_on("hsakmt-roct", when="^hip@:6.2") + depends_on("llvm-amdgpu") depends_on("py-nanobind") with default_args(type="build"): @@ -113,6 +127,7 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): depends_on("python@3.9:", when="@0.4.14:") depends_on("python@3.8:", when="@0.4.6:") depends_on("python@:3.13") + depends_on("python@:3.12", when="+rocm") depends_on("python@:3.12", when="@:0.4.33") depends_on("python@:3.11", when="@:0.4.16") @@ -167,6 +182,22 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): # Fails to build with freshly released CUDA (#48708). 
conflicts("^cuda@12.8:", when="@:0.4.31")
 
+    # external CUDA is not supported https://github.com/jax-ml/jax/issues/23689
+    conflicts("+cuda", when="@0.4.32:")
+
+    # aarch64 is not supported https://github.com/jax-ml/jax/issues/25598
+    conflicts("target=aarch64:", when="@0.4.32:")
+
+    resource(
+        name="xla",
+        url="https://github.com/ROCm/xla/archive/07543ab117699a57c1267b453a62f89b1d5953fd.tar.gz",
+        sha256="cee377479654201c61cc3f230d89603cd589525fea2faf44564a23c70ba1448d",
+        expand=True,
+        destination="",
+        placement="xla",
+        when="@0.4.38:0.5.2 +rocm",
+    )
+
     def url_for_version(self, version):
         url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz"
         if version >= Version("0.4.33"):
@@ -175,6 +206,20 @@ def url_for_version(self, version):
             name = "jaxlib"
         return url.format(name, version)
 
+    def setup_build_environment(self, env):
+        """Set LLVM_PATH and TF_ROCM_MULTIPLE_PATHS for @0.4.38: +rocm with non-external hip."""
+        spec = self.spec
+        if spec.satisfies("@0.4.38: +rocm") and not spec["hip"].external:
+            # Use a copy: appending to the module-level list would leak entries across calls.
+            rocm_deps = list(rocm_dependencies)
+            if spec.satisfies("^hip@6.2:"):
+                rocm_deps.append("rocprofiler-register")
+            rocm_deps.append("hipblas-common" if spec.satisfies("^hip@6.3:") else "hsakmt-roct")
+            env.set("LLVM_PATH", spec["llvm-amdgpu"].prefix)
+            for pkg_dep in rocm_deps:
+                env.prepend_path("TF_ROCM_MULTIPLE_PATHS", spec[pkg_dep].prefix)
+            env.prune_duplicate_paths("TF_ROCM_MULTIPLE_PATHS")
+
     def install(self, spec, prefix):
         # https://jax.readthedocs.io/en/latest/developer.html
         args = ["build/build.py"]
@@ -216,7 +261,15 @@ def install(self, spec, prefix):
                 args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}")
 
         if "+rocm" in spec:
-            args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"])
+            args.append(f"--rocm_path={self.spec['hip'].prefix}")
+            if spec.satisfies("@:0.4.35"):
+                args.append("--enable_rocm")
+            if spec.satisfies("@0.4.38:") and not spec["hip"].external:
+                args.append("--bazel_options=--@local_config_rocm//rocm:rocm_path_type=multiple")
+            if spec.satisfies("@0.4.38:0.5.2"):
+                args.append(
+                    f"--bazel_options=--override_repository=xla={self.stage.source_path}/xla"
+                )
 
         args.extend(
             [
@@ -227,5 +280,6 @@ def install(self, spec, prefix):
         )
 
         python(*args)
-        whl = glob.glob(join_path("dist", "*.whl"))[0]
-        pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)
+
+        for whl in glob.glob(join_path("dist", "*.whl")):
+            pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)