py-tensorflow: add 2.18.0-rocm-enhanced (#48711)
* py-tensorflow: add 2.18.0-rocm-enhanced * fix style * fix style * fix style * review changes * review changes * remove hipblaslt dependency * remove ci changes and force ROCm 6.3.1 for newest TF * remove rocm 6.3.1 dependency * simplify configure fix
This commit is contained in:
parent
5c91667dab
commit
dd98cfb839
@ -3,7 +3,6 @@
|
||||
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
|
||||
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
@ -46,6 +45,11 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
|
||||
license("Apache-2.0")
|
||||
maintainers("adamjstewart", "aweits")
|
||||
|
||||
version(
|
||||
"2.18.0-rocm-enhanced",
|
||||
sha256="85f44bed166927b2e22db28f5c4e4538da22221fedd9c2f47c763c52a0e40814",
|
||||
url="https://github.com/ROCm/tensorflow-upstream/archive/refs/tags/v2.18.0-rocm-enhanced.tar.gz",
|
||||
)
|
||||
version("2.18.0", sha256="d7876f4bb0235cac60eb6316392a7c48676729860da1ab659fb440379ad5186d")
|
||||
version("2.17.1", sha256="2d3cfb48510f92f3a52fb05b820481c6f066a342a9f5296fe26d72c4ea757700")
|
||||
version("2.17.0", sha256="9cc4d5773b8ee910079baaecb4086d0c28939f024dd74b33fc5e64779b6533dc")
|
||||
@ -440,11 +444,25 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
|
||||
conflicts("platform=darwin target=aarch64:", when="@:2.4")
|
||||
# https://github.com/tensorflow/tensorflow/pull/39225
|
||||
conflicts("target=aarch64:", when="@:2.2")
|
||||
conflicts(
|
||||
"~rocm",
|
||||
when="@2.7.4-rocm-enhanced,2.11.0-rocm-enhanced,2.14-rocm-enhanced,2.16.1-rocm-enhanced",
|
||||
)
|
||||
conflicts("+rocm", when="@:2.7.4-a,2.7.4.0:2.11.0-a,2.11.0.0:2.14-a,2.14-z:2.16.1-a,2.16.1-z:")
|
||||
|
||||
rocm_versions = [
|
||||
"2.7.4-rocm-enhanced",
|
||||
"2.11.0-rocm-enhanced",
|
||||
"2.14-rocm-enhanced",
|
||||
"2.16.1-rocm-enhanced",
|
||||
"2.18.0-rocm-enhanced",
|
||||
]
|
||||
rocm_conflicts = [
|
||||
":2.7.4-a",
|
||||
"2.7.4.0:2.11.0-a",
|
||||
"2.11.0.0:2.14-a",
|
||||
"2.14-z:2.16.1-a",
|
||||
"2.16.1-z:2.18.0-a",
|
||||
"2.18.0-z:",
|
||||
]
|
||||
conflicts("~rocm", when=f"@{','.join(rocm_versions)}")
|
||||
conflicts("+rocm", when=f"@{','.join(rocm_conflicts)}")
|
||||
|
||||
# wheel 0.40 upgrades vendored packaging, trips over tensorflow-io-gcs-filesystem identifier
|
||||
conflicts("^py-wheel@0.40:", when="@2.11:2.13")
|
||||
|
||||
@ -510,13 +528,14 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
|
||||
patch(
|
||||
"https://github.com/ROCm/tensorflow-upstream/commit/f4f4e8698b90755b0b5ea2d9da1933b0b988b111.patch?full_index=1",
|
||||
sha256="a4c0fd62a0af3ba113c8933fa531dd17fa6667e507202a144715cd87fbdaf476",
|
||||
when="@2.16.1-rocm-enhanced: +rocm",
|
||||
when="@2.16.1-rocm-enhanced +rocm",
|
||||
)
|
||||
patch(
|
||||
"https://github.com/ROCm/tensorflow-upstream/commit/8b7fcccb2914078737689347540cb79ace579bbb.patch?full_index=1",
|
||||
sha256="75a61a79ce3aae51fda920f677f4dc045374b20e25628626eb37ca19c3a3b4c4",
|
||||
when="@2.16.1-rocm-enhanced +rocm",
|
||||
)
|
||||
patch("set_jit_true.patch", when="@2.18.0-rocm-enhanced +rocm")
|
||||
phases = ["configure", "build", "install"]
|
||||
|
||||
def flag_handler(self, name, flags):
|
||||
@ -850,17 +869,10 @@ def post_configure_fixes(self):
|
||||
with open(".tf_configure.bazelrc", mode="a") as f:
|
||||
f.write('build --action_env LD_LIBRARY_PATH="' + slibs + '"')
|
||||
|
||||
if spec.satisfies("@2.16.1-rocm-enhanced +rocm"):
|
||||
if os.path.exists(spec["llvm-amdgpu"].prefix.bin.clang):
|
||||
filter_file(
|
||||
"/usr/lib/llvm-17/bin/clang", spec["llvm-amdgpu"].prefix.bin.clang, ".bazelrc"
|
||||
)
|
||||
else:
|
||||
filter_file(
|
||||
"/usr/lib/llvm-17/bin/clang",
|
||||
spec["llvm-amdgpu"].prefix.llvm.bin.clang,
|
||||
".bazelrc",
|
||||
)
|
||||
if spec.satisfies("+rocm"):
|
||||
before = r"/usr/lib/llvm-\d+/bin/clang"
|
||||
after = spec["llvm-amdgpu"].prefix.bin.clang
|
||||
filter_file(before, after, ".bazelrc")
|
||||
|
||||
filter_file("build:opt --copt=-march=native", "", ".tf_configure.bazelrc")
|
||||
filter_file("build:opt --host_copt=-march=native", "", ".tf_configure.bazelrc")
|
||||
@ -938,6 +950,11 @@ def build(self, spec, prefix):
|
||||
|
||||
args.append("--config=v2")
|
||||
|
||||
if self.spec.satisfies("@2.18.0-rocm-enhanced: +rocm"):
|
||||
buildpath = join_path(
|
||||
self.stage.source_path, "bazel-bin/tensorflow/tools/pip_package/wheel_house/"
|
||||
)
|
||||
args.append(f"--repo_env=OUTPUT_PATH={buildpath}")
|
||||
# https://github.com/tensorflow/tensorflow/issues/63298
|
||||
if self.spec.satisfies("@2.17:"):
|
||||
args.append("//tensorflow/tools/pip_package:wheel")
|
||||
|
@ -0,0 +1,22 @@
|
||||
diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
|
||||
index f574a8da8fd..fc1fbf68bf8 100644
|
||||
--- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl
|
||||
+++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
|
||||
@@ -360,7 +360,7 @@ def _gen_kernel_library(
|
||||
extra_args = extra_args,
|
||||
host_triple = host_triple,
|
||||
gpu_archs = gpu_archs,
|
||||
- jit = jit,
|
||||
+ jit = True,
|
||||
mlir_op = "{op}_{name}_{platform}_{type}_{output_type}.mlir".format(
|
||||
op = op,
|
||||
name = name,
|
||||
@@ -370,7 +370,7 @@ def _gen_kernel_library(
|
||||
),
|
||||
tile_size = typed_tile_size,
|
||||
unroll_factors = typed_unroll_factors,
|
||||
- jit_i64_indexed_for_large_tensors = jit_i64_indexed_for_large_tensors,
|
||||
+ jit_i64_indexed_for_large_tensors = False,
|
||||
)
|
||||
|
||||
# We have to use a sh_test instead of build_test because it doesn't properly find the dependent targets.
|
Loading…
Reference in New Issue
Block a user