py-tensorflow: add 2.18.0-rocm-enhanced (#48711)

* py-tensorflow: add 2.18.0-rocm-enhanced

* fix style

* fix style

* fix style

* review changes

* review changes

* remove hipblaslt dependency

* remove ci changes and force ROCm 6.3.1 for newest TF

* remove rocm 6.3.1 dependency

* simplify configure fix
This commit is contained in:
afzpatel 2025-02-03 15:38:25 -05:00 committed by GitHub
parent 5c91667dab
commit dd98cfb839
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 18 deletions

View File

@ -3,7 +3,6 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import glob
import os
import sys
import tempfile
@ -46,6 +45,11 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
license("Apache-2.0")
maintainers("adamjstewart", "aweits")
# The "-rocm-enhanced" release is fetched from an explicit tarball URL in the
# ROCm/tensorflow-upstream repository rather than relying on a URL derived
# from the version number.
version(
"2.18.0-rocm-enhanced",
sha256="85f44bed166927b2e22db28f5c4e4538da22221fedd9c2f47c763c52a0e40814",
url="https://github.com/ROCm/tensorflow-upstream/archive/refs/tags/v2.18.0-rocm-enhanced.tar.gz",
)
version("2.18.0", sha256="d7876f4bb0235cac60eb6316392a7c48676729860da1ab659fb440379ad5186d")
version("2.17.1", sha256="2d3cfb48510f92f3a52fb05b820481c6f066a342a9f5296fe26d72c4ea757700")
version("2.17.0", sha256="9cc4d5773b8ee910079baaecb4086d0c28939f024dd74b33fc5e64779b6533dc")
@ -440,11 +444,25 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
conflicts("platform=darwin target=aarch64:", when="@:2.4")
# https://github.com/tensorflow/tensorflow/pull/39225
conflicts("target=aarch64:", when="@:2.2")
conflicts(
"~rocm",
when="@2.7.4-rocm-enhanced,2.11.0-rocm-enhanced,2.14-rocm-enhanced,2.16.1-rocm-enhanced",
)
conflicts("+rocm", when="@:2.7.4-a,2.7.4.0:2.11.0-a,2.11.0.0:2.14-a,2.14-z:2.16.1-a,2.16.1-z:")
# The "-rocm-enhanced" versions. They are joined below into a single
# comma-separated version constraint that is declared to conflict with ~rocm.
rocm_versions = [
"2.7.4-rocm-enhanced",
"2.11.0-rocm-enhanced",
"2.14-rocm-enhanced",
"2.16.1-rocm-enhanced",
"2.18.0-rocm-enhanced",
]
# Version ranges that are joined below into a single comma-separated version
# constraint declared to conflict with +rocm.
# NOTE(review): the "-a" / "-z" / ".0" range endpoints appear chosen to sort
# just below or just above each "-rocm-enhanced" release, so that the ranges
# cover every version except those in rocm_versions -- confirm against Spack's
# version ordering rules. When adding a new entry to rocm_versions, the last
# open-ended range here presumably has to be split around it as well.
rocm_conflicts = [
":2.7.4-a",
"2.7.4.0:2.11.0-a",
"2.11.0.0:2.14-a",
"2.14-z:2.16.1-a",
"2.16.1-z:2.18.0-a",
"2.18.0-z:",
]
# Both constraints are built as "@<item>,<item>,..." from the lists above.
conflicts("~rocm", when=f"@{','.join(rocm_versions)}")
conflicts("+rocm", when=f"@{','.join(rocm_conflicts)}")
# wheel 0.40 upgrades vendored packaging, trips over tensorflow-io-gcs-filesystem identifier
conflicts("^py-wheel@0.40:", when="@2.11:2.13")
@ -510,13 +528,14 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
patch(
"https://github.com/ROCm/tensorflow-upstream/commit/f4f4e8698b90755b0b5ea2d9da1933b0b988b111.patch?full_index=1",
sha256="a4c0fd62a0af3ba113c8933fa531dd17fa6667e507202a144715cd87fbdaf476",
when="@2.16.1-rocm-enhanced: +rocm",
when="@2.16.1-rocm-enhanced +rocm",
)
patch(
"https://github.com/ROCm/tensorflow-upstream/commit/8b7fcccb2914078737689347540cb79ace579bbb.patch?full_index=1",
sha256="75a61a79ce3aae51fda920f677f4dc045374b20e25628626eb37ca19c3a3b4c4",
when="@2.16.1-rocm-enhanced +rocm",
)
patch("set_jit_true.patch", when="@2.18.0-rocm-enhanced +rocm")
phases = ["configure", "build", "install"]
def flag_handler(self, name, flags):
@ -850,17 +869,10 @@ def post_configure_fixes(self):
with open(".tf_configure.bazelrc", mode="a") as f:
f.write('build --action_env LD_LIBRARY_PATH="' + slibs + '"')
if spec.satisfies("@2.16.1-rocm-enhanced +rocm"):
if os.path.exists(spec["llvm-amdgpu"].prefix.bin.clang):
filter_file(
"/usr/lib/llvm-17/bin/clang", spec["llvm-amdgpu"].prefix.bin.clang, ".bazelrc"
)
else:
filter_file(
"/usr/lib/llvm-17/bin/clang",
spec["llvm-amdgpu"].prefix.llvm.bin.clang,
".bazelrc",
)
if spec.satisfies("+rocm"):
before = r"/usr/lib/llvm-\d+/bin/clang"
after = spec["llvm-amdgpu"].prefix.bin.clang
filter_file(before, after, ".bazelrc")
filter_file("build:opt --copt=-march=native", "", ".tf_configure.bazelrc")
filter_file("build:opt --host_copt=-march=native", "", ".tf_configure.bazelrc")
@ -938,6 +950,11 @@ def build(self, spec, prefix):
args.append("--config=v2")
if self.spec.satisfies("@2.18.0-rocm-enhanced: +rocm"):
buildpath = join_path(
self.stage.source_path, "bazel-bin/tensorflow/tools/pip_package/wheel_house/"
)
args.append(f"--repo_env=OUTPUT_PATH={buildpath}")
# https://github.com/tensorflow/tensorflow/issues/63298
if self.spec.satisfies("@2.17:"):
args.append("//tensorflow/tools/pip_package:wheel")

View File

@ -0,0 +1,22 @@
diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
index f574a8da8fd..fc1fbf68bf8 100644
--- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl
+++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
@@ -360,7 +360,7 @@ def _gen_kernel_library(
extra_args = extra_args,
host_triple = host_triple,
gpu_archs = gpu_archs,
- jit = jit,
+ jit = True,
mlir_op = "{op}_{name}_{platform}_{type}_{output_type}.mlir".format(
op = op,
name = name,
@@ -370,7 +370,7 @@ def _gen_kernel_library(
),
tile_size = typed_tile_size,
unroll_factors = typed_unroll_factors,
- jit_i64_indexed_for_large_tensors = jit_i64_indexed_for_large_tensors,
+ jit_i64_indexed_for_large_tensors = False,
)
# We have to use a sh_test instead of build_test because it doesn't properly find the dependent targets.