spec.py: fix ArchSpec.intersects (#48741)

fixes a bug where `x86_64:` and `ppc64le:` intersected, and x86_64: and :haswell did not.
This commit is contained in:
Massimiliano Culpo 2025-01-28 16:46:09 +01:00 committed by GitHub
parent 82e091e2c2
commit 40a1da4a73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 36 deletions

View File

@ -461,6 +461,9 @@ def _target_satisfies(self, other: "ArchSpec", strict: bool) -> bool:
return bool(self._target_intersection(other)) return bool(self._target_intersection(other))
def _target_constrain(self, other: "ArchSpec") -> bool: def _target_constrain(self, other: "ArchSpec") -> bool:
if self.target is None and other.target is None:
return False
if not other._target_satisfies(self, strict=False): if not other._target_satisfies(self, strict=False):
raise UnsatisfiableArchitectureSpecError(self, other) raise UnsatisfiableArchitectureSpecError(self, other)
@ -509,21 +512,56 @@ def _target_intersection(self, other):
if (not s_min or o_comp >= s_min) and (not s_max or o_comp <= s_max): if (not s_min or o_comp >= s_min) and (not s_max or o_comp <= s_max):
results.append(o_min) results.append(o_min)
else: else:
# Take intersection of two ranges # Take the "min" of the two max, if there is a partial ordering.
# Lots of comparisons needed n_max = ""
_s_min = _make_microarchitecture(s_min) if s_max and o_max:
_s_max = _make_microarchitecture(s_max) _s_max = _make_microarchitecture(s_max)
_o_min = _make_microarchitecture(o_min) _o_max = _make_microarchitecture(o_max)
_o_max = _make_microarchitecture(o_max) if _s_max.family != _o_max.family:
continue
if _s_max <= _o_max:
n_max = s_max
elif _o_max < _s_max:
n_max = o_max
else:
continue
elif s_max:
n_max = s_max
elif o_max:
n_max = o_max
# Take the "max" of the two min.
n_min = ""
if s_min and o_min:
_s_min = _make_microarchitecture(s_min)
_o_min = _make_microarchitecture(o_min)
if _s_min.family != _o_min.family:
continue
if _s_min >= _o_min:
n_min = s_min
elif _o_min > _s_min:
n_min = o_min
else:
continue
elif s_min:
n_min = s_min
elif o_min:
n_min = o_min
if n_min and n_max:
_n_min = _make_microarchitecture(n_min)
_n_max = _make_microarchitecture(n_max)
if _n_min.family != _n_max.family or not _n_min <= _n_max:
continue
if n_min == n_max:
results.append(n_min)
else:
results.append(f"{n_min}:{n_max}")
elif n_min:
results.append(f"{n_min}:")
elif n_max:
results.append(f":{n_max}")
n_min = s_min if _s_min >= _o_min else o_min
n_max = s_max if _s_max <= _o_max else o_max
_n_min = _make_microarchitecture(n_min)
_n_max = _make_microarchitecture(n_max)
if _n_min == _n_max:
results.append(n_min)
elif not n_min or not n_max or _n_min < _n_max:
results.append("%s:%s" % (n_min, n_max))
return results return results
def constrain(self, other: "ArchSpec") -> bool: def constrain(self, other: "ArchSpec") -> bool:
@ -3151,18 +3189,13 @@ def constrain(self, other, deps=True):
if not self.variants[v].compatible(other.variants[v]): if not self.variants[v].compatible(other.variants[v]):
raise vt.UnsatisfiableVariantSpecError(self.variants[v], other.variants[v]) raise vt.UnsatisfiableVariantSpecError(self.variants[v], other.variants[v])
# TODO: Check out the logic here
sarch, oarch = self.architecture, other.architecture sarch, oarch = self.architecture, other.architecture
if sarch is not None and oarch is not None: if (
if sarch.platform is not None and oarch.platform is not None: sarch is not None
if sarch.platform != oarch.platform: and oarch is not None
raise UnsatisfiableArchitectureSpecError(sarch, oarch) and not self.architecture.intersects(other.architecture)
if sarch.os is not None and oarch.os is not None: ):
if sarch.os != oarch.os: raise UnsatisfiableArchitectureSpecError(sarch, oarch)
raise UnsatisfiableArchitectureSpecError(sarch, oarch)
if sarch.target is not None and oarch.target is not None:
if sarch.target != oarch.target:
raise UnsatisfiableArchitectureSpecError(sarch, oarch)
changed = False changed = False
@ -3185,18 +3218,12 @@ def constrain(self, other, deps=True):
changed |= self.compiler_flags.constrain(other.compiler_flags) changed |= self.compiler_flags.constrain(other.compiler_flags)
old = str(self.architecture)
sarch, oarch = self.architecture, other.architecture sarch, oarch = self.architecture, other.architecture
if sarch is None or other.architecture is None: if sarch is not None and oarch is not None:
self.architecture = sarch or oarch changed |= self.architecture.constrain(other.architecture)
else: elif oarch is not None:
if sarch.platform is None or oarch.platform is None: self.architecture = oarch
self.architecture.platform = sarch.platform or oarch.platform changed = True
if sarch.os is None or oarch.os is None:
sarch.os = sarch.os or oarch.os
if sarch.target is None or oarch.target is None:
sarch.target = sarch.target or oarch.target
changed |= str(self.architecture) != old
if deps: if deps:
changed |= self._constrain_dependencies(other) changed |= self._constrain_dependencies(other)

View File

@ -1834,6 +1834,16 @@ def test_abstract_contains_semantic(lhs, rhs, expected, mock_packages):
# Different virtuals intersect if there is at least package providing both # Different virtuals intersect if there is at least package providing both
(Spec, "mpi", "lapack", (True, False, False)), (Spec, "mpi", "lapack", (True, False, False)),
(Spec, "mpi", "pkgconfig", (False, False, False)), (Spec, "mpi", "pkgconfig", (False, False, False)),
# Intersection among target ranges for different architectures
(Spec, "target=x86_64:", "target=ppc64le:", (False, False, False)),
(Spec, "target=x86_64:", "target=:power9", (False, False, False)),
(Spec, "target=:haswell", "target=:power9", (False, False, False)),
(Spec, "target=:haswell", "target=ppc64le:", (False, False, False)),
# Intersection among target ranges for the same architecture
(Spec, "target=:haswell", "target=x86_64:", (True, True, True)),
(Spec, "target=:haswell", "target=x86_64_v4:", (False, False, False)),
# Edge case of uarch that split in a diamond structure, from a common ancestor
(Spec, "target=:cascadelake", "target=:cannonlake", (False, False, False)),
], ],
) )
def test_intersects_and_satisfies(factory, lhs_str, rhs_str, results): def test_intersects_and_satisfies(factory, lhs_str, rhs_str, results):
@ -1883,6 +1893,16 @@ def test_intersects_and_satisfies(factory, lhs_str, rhs_str, results):
# Flags # Flags
(Spec, "cppflags=-foo", "cppflags=-foo", False, "cppflags=-foo"), (Spec, "cppflags=-foo", "cppflags=-foo", False, "cppflags=-foo"),
(Spec, "cppflags=-foo", "cflags=-foo", True, "cppflags=-foo cflags=-foo"), (Spec, "cppflags=-foo", "cflags=-foo", True, "cppflags=-foo cflags=-foo"),
# Target ranges
(Spec, "target=x86_64:", "target=x86_64:", False, "target=x86_64:"),
(Spec, "target=x86_64:", "target=:haswell", True, "target=x86_64:haswell"),
(
Spec,
"target=x86_64:haswell",
"target=x86_64_v2:icelake",
True,
"target=x86_64_v2:haswell",
),
], ],
) )
def test_constrain(factory, lhs_str, rhs_str, result, constrained_str): def test_constrain(factory, lhs_str, rhs_str, result, constrained_str):