From 7a82c703c751e1f4f30d304eb06252460f1aff9c Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Tue, 28 Jan 2025 13:37:50 +0100 Subject: [PATCH] JAX: add v0.4.32+ (#46346) * JAX: add v0.4.34 * Disable search for clang * Update CUDA flags * Add py-jax 0.4.33, comment out until py-jaxlib 0.4.33 is also released * Fix GCC build * Try TF_NVCC_CLANG * py-jax: add v0.4.34 * jax no longer has separate tags for jaxlib * Install compiled wheel * Join path before glob * Wheel is in spack stage, not tmp path * Add 0.4.35 * Add newer versions * Build system has been refactored yet again * Drop clang * Fix build with source tarball, rocm support * Support GCC * Remove clang-specific compiler flags * enable_cuda flag was removed * Fix logic * py-jax: add v0.4.38 * Add patch to fix GCC support * Patch no longer needed * Skip patching, directly pass flags * New flags * Remove unused import * Patch changed * Use older version of patch * Newer patch * Add CUDA symlink * Symlink more directories * Recursive symlink * Import function * Recursive search * Undo cuda changes * Add v0.5.0 * I quit --- .../repos/builtin/packages/py-jax/package.py | 43 +++-- .../builtin/packages/py-jaxlib/package.py | 153 ++++++++++-------- 2 files changed, 122 insertions(+), 74 deletions(-) diff --git a/var/spack/repos/builtin/packages/py-jax/package.py b/var/spack/repos/builtin/packages/py-jax/package.py index 69a64564316..f99d02a83a2 100644 --- a/var/spack/repos/builtin/packages/py-jax/package.py +++ b/var/spack/repos/builtin/packages/py-jax/package.py @@ -7,22 +7,26 @@ class PyJax(PythonPackage): - """JAX is Autograd and XLA, brought together for high-performance - machine learning research. With its updated version of Autograd, - JAX can automatically differentiate native Python and NumPy - functions. It can differentiate through loops, branches, - recursion, and closures, and it can take derivatives of - derivatives of derivatives. It supports reverse-mode - differentiation (a.k.a. backpropagation) via grad as well as - forward-mode differentiation, and the two can be composed - arbitrarily to any order.""" + """Differentiate, compile, and transform Numpy code. - homepage = "https://github.com/google/jax" + JAX is a Python library for accelerator-oriented array computation and program transformation, + designed for high-performance numerical computing and large-scale machine learning. + """ + + homepage = "https://github.com/jax-ml/jax" pypi = "jax/jax-0.4.27.tar.gz" license("Apache-2.0") maintainers("adamjstewart", "jonas-eschle") + # version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8") + # version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8") + # version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b") + # version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a") + # version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e") + # version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db") + # version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92") + # version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08") version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287") version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577") version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186") @@ -59,9 +63,11 @@ class PyJax(PythonPackage): # setup.py depends_on("python@3.10:", when="@0.4.31:") depends_on("python@3.9:", when="@0.4.14:") + depends_on("py-ml-dtypes@0.4:", when="@0.4.29,0.4.35:") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") + depends_on("py-numpy@1.25:", when="@0.5:") depends_on("py-numpy@1.24:", when="@0.4.31:") depends_on("py-numpy@1.22:", when="@0.4.14:") depends_on("py-numpy@1.21:", when="@0.4.7:") @@ -69,6 +75,7 @@ class PyJax(PythonPackage): # https://github.com/google/jax/issues/19246 depends_on("py-numpy@:1", when="@:0.4.25") depends_on("py-opt-einsum") + depends_on("py-scipy@1.11.1:", when="@0.5:") depends_on("py-scipy@1.10:", when="@0.4.31:") depends_on("py-scipy@1.9:", when="@0.4.19:") depends_on("py-scipy@1.7:", when="@0.4.7:") @@ -77,6 +84,14 @@ class PyJax(PythonPackage): # jax/_src/lib/__init__.py # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4 for v in [ + # "0.5.0", + # "0.4.38", + # "0.4.37", + # "0.4.36", + # "0.4.35", + # "0.4.34", + # "0.4.33", + # "0.4.32", "0.4.31", "0.4.30", "0.4.29", @@ -110,6 +125,13 @@ class PyJax(PythonPackage): depends_on(f"py-jaxlib@:{v}", when=f"@{v}") # See _minimum_jaxlib_version in jax/version.py + # depends_on("py-jaxlib@0.5:", when="@0.5:") + # depends_on("py-jaxlib@0.4.38:", when="@0.4.38:") + # depends_on("py-jaxlib@0.4.36:", when="@0.4.36:") + # depends_on("py-jaxlib@0.4.35:", when="@0.4.35:") + # depends_on("py-jaxlib@0.4.34:", when="@0.4.34:") + # depends_on("py-jaxlib@0.4.33:", when="@0.4.33:") + # depends_on("py-jaxlib@0.4.32:", when="@0.4.32:") depends_on("py-jaxlib@0.4.30:", when="@0.4.31:") depends_on("py-jaxlib@0.4.27:", when="@0.4.28:") depends_on("py-jaxlib@0.4.23:", when="@0.4.27:") @@ -124,5 +146,4 @@ class PyJax(PythonPackage): depends_on("py-jaxlib@0.4.1:", when="@0.4.2:") # Historical dependencies - depends_on("py-ml-dtypes@0.4:", when="@0.4.29") depends_on("py-importlib-metadata@4.6:", when="@0.4.11:0.4.30 ^python@:3.9") diff --git a/var/spack/repos/builtin/packages/py-jaxlib/package.py b/var/spack/repos/builtin/packages/py-jaxlib/package.py index 80cfa8c6006..337576cb3cd 100644 --- a/var/spack/repos/builtin/packages/py-jaxlib/package.py +++ b/var/spack/repos/builtin/packages/py-jaxlib/package.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: (Apache-2.0 OR MIT) -import tempfile +import glob from spack.build_systems.python import PythonPipBuilder from spack.package import * @@ -26,17 +26,27 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): - """XLA library for Jax""" + """XLA library for Jax. - homepage = "https://github.com/google/jax" - url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.4.27.tar.gz" + jaxlib is the support library for JAX. While JAX itself is a pure Python package, + jaxlib contains the binary (C/C++) parts of the library, including Python bindings, + the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. + """ - tmp_path = "" - buildtmp = "" + homepage = "https://github.com/jax-ml/jax" + url = "https://github.com/jax-ml/jax/archive/refs/tags/jax-v0.4.34.tar.gz" license("Apache-2.0") maintainers("adamjstewart", "jonas-eschle") + # version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9") + # version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c") + # version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870") + # version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77") + # version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212") + # version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb") + # version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31") + # version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2") version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94") version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6") version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246") @@ -57,12 +67,12 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806") version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910") - depends_on("c", type="build") - depends_on("cxx", type="build") - variant("cuda", default=True, description="Build with CUDA enabled") variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda") + depends_on("c", type="build") + depends_on("cxx", type="build") + # docs/installation.md (Compatible with) with when("+cuda"): depends_on("cuda@12.1:", when="@0.4.26:") @@ -98,32 +108,42 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): depends_on("py-build", when="@0.4.14:") with default_args(type=("build", "run")): + # Based on PyPI wheels depends_on("python@3.10:", when="@0.4.31:") depends_on("python@3.9:", when="@0.4.14:") depends_on("python@3.8:", when="@0.4.6:") - # Based on PyPI wheels depends_on("python@:3.13") depends_on("python@:3.12", when="@:0.4.33") depends_on("python@:3.11", when="@:0.4.16") # jaxlib/setup.py + depends_on("py-scipy@1.11.1:", when="@0.5:") depends_on("py-scipy@1.10:", when="@0.4.31:") depends_on("py-scipy@1.9:", when="@0.4.19:") depends_on("py-scipy@1.7:", when="@0.4.7:") depends_on("py-scipy@1.5:") + depends_on("py-numpy@1.25:", when="@0.5:") depends_on("py-numpy@1.24:", when="@0.4.31:") depends_on("py-numpy@1.22:", when="@0.4.14:") depends_on("py-numpy@1.21:", when="@0.4.7:") depends_on("py-numpy@1.20:", when="@0.3:") + # https://github.com/google/jax/issues/19246 + depends_on("py-numpy@:1", when="@:0.4.25") + depends_on("py-ml-dtypes@0.4:", when="@0.4.29") depends_on("py-ml-dtypes@0.2:", when="@0.4.14:") depends_on("py-ml-dtypes@0.1:", when="@0.4.9:") depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:") - # Historical dependencies - # https://github.com/google/jax/issues/19246 - depends_on("py-numpy@:1", when="@:0.4.25") - depends_on("py-ml-dtypes@0.4:", when="@0.4.29") - + patch( + "https://github.com/jax-ml/jax/commit/f62af6457a6cc575a7b1ada08d541f0dd0eb5765.patch?full_index=1", + sha256="d3b7ea2cfeba927e40a11f07e4cbf80939f7fe69448c9eb55231a93bd64e5c02", + when="@0.4.36:0.4.38", + ) + patch( + "https://github.com/jax-ml/jax/pull/25473.patch?full_index=1", + sha256="9d6977bc32046600bf8b15863251283fe7546896340367a7f14e3dccf418b4fe", + when="@0.4.36:0.4.37", + ) patch( "https://github.com/google/jax/pull/20101.patch?full_index=1", sha256="4dfb9f32d4eeb0a0fb3a6f4124c4170e3fe49511f1b768cd634c78d489962275", @@ -144,58 +164,65 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage): # https://github.com/google/jax/issues/19992 conflicts("@0.4.4:", when="target=ppc64le:") - def patch(self): - self.tmp_path = tempfile.mkdtemp(prefix="spack") - self.buildtmp = tempfile.mkdtemp(prefix="spack") - filter_file( - "build --spawn_strategy=standalone", - f""" -# Limit CPU workers to spack jobs instead of using all HOST_CPUS. -build --spawn_strategy=standalone -build --local_cpu_resources={make_jobs} -""".strip(), - ".bazelrc", - string=True, - ) - filter_file( - 'f"--output_path={output_path}",', - 'f"--output_path={output_path}",' - f' "--sources_path={self.tmp_path}",' - ' "--nohome_rc",' - ' "--nosystem_rc",' - f' "--jobs={make_jobs}",', - "build/build.py", - string=True, - ) - build_wheel = join_path("build", "build_wheel.py") - if self.spec.satisfies("@0.4.14:"): - build_wheel = join_path("jaxlib", "tools", "build_wheel.py") - filter_file( - "args = parser.parse_args()", - "args, junk = parser.parse_known_args()", - build_wheel, - string=True, - ) + def url_for_version(self, version): + url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz" + if version >= Version("0.4.33"): + name = "jax" + else: + name = "jaxlib" + return url.format(name, version) def install(self, spec, prefix): - args = [] - args.append("build/build.py") + # https://jax.readthedocs.io/en/latest/developer.html + args = ["build/build.py"] + + if spec.satisfies("@0.4.36:"): + args.append("build") + + if spec.satisfies("+cuda"): + args.append("--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt") + elif spec.satisfies("+rocm"): + args.append("--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt") + else: + args.append("--wheels=jaxlib") + + if spec.satisfies("@0.4.32:"): + if spec.satisfies("%clang"): + args.append("--use_clang=true") + else: + args.append("--use_clang=false") + if "+cuda" in spec: - args.append("--enable_cuda") - args.append("--cuda_path={0}".format(self.spec["cuda"].prefix)) - args.append("--cudnn_path={0}".format(self.spec["cudnn"].prefix)) capabilities = CudaPackage.compute_capabilities(spec.variants["cuda_arch"].value) - args.append("--cuda_compute_capabilities={0}".format(",".join(capabilities))) - args.append( - "--bazel_startup_options=" - "--output_user_root={0}".format(self.wrapped_package_object.buildtmp) - ) + args.append(f"--cuda_compute_capabilities={','.join(capabilities)}") + if spec.satisfies("@:0.4.35"): + args.append("--enable_cuda") + if spec.satisfies("@0.4.32:"): + args.extend( + [ + f"--bazel_options=--repo_env=LOCAL_CUDA_PATH={spec['cuda'].prefix}", + f"--bazel_options=--repo_env=LOCAL_CUDNN_PATH={spec['cudnn'].prefix}", + ] + ) + else: + args.extend( + [f"--cuda_path={spec['cuda'].prefix}", f"--cudnn_path={spec['cudnn'].prefix}"] + ) + + if "+nccl" in spec and spec.satisfies("@0.4.32:"): + args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}") + if "+rocm" in spec: - args.append("--enable_rocm") - args.append("--rocm_path={0}".format(self.spec["hip"].prefix)) + args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"]) + + args.extend( + [ + f"--bazel_options=--jobs={make_jobs}", + "--bazel_startup_options=--nohome_rc", + "--bazel_startup_options=--nosystem_rc", + ] + ) python(*args) - with working_dir(self.wrapped_package_object.tmp_path): - pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", ".") - remove_linked_tree(self.wrapped_package_object.tmp_path) - remove_linked_tree(self.wrapped_package_object.buildtmp) + whl = glob.glob(join_path("dist", "*.whl"))[0] + pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)