versions: simplify list if union not disjoint (#42902)

Spack merges ranges and concrete versions if they have non-empty
intersection. That is not enough for adjacent version ranges.

This commit ensures that disjoint ranges in version lists are simplified
if their union is not disjoint:

```python
"@1.0:2.0,2.1,2.2:3,4:6" # simplifies to "@1.0:6"
```
This commit is contained in:
Harmen Stoppels 2024-02-28 16:33:25 +01:00 committed by GitHub
parent 287e1039f5
commit 661ae1f230
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 25 deletions

View File

@ -906,6 +906,13 @@ def test_version_list_normalization():
assert ver("1.0:2.0,=1.0,ref=1.0") == ver(["1.0:2.0"]) assert ver("1.0:2.0,=1.0,ref=1.0") == ver(["1.0:2.0"])
def test_version_list_connected_union_of_disjoint_ranges():
# Make sure that we also simplify lists of ranges if their intersection is empty, but their
# union is connected.
assert ver("1.0:2.0,2.1,2.2:3,4:6") == ver(["1.0:6"])
assert ver("1.0:1.2,1.3:2") == ver("1.0:1.5,1.6:2")
@pytest.mark.parametrize("version", ["=1.2", "git.ref=1.2", "1.2"]) @pytest.mark.parametrize("version", ["=1.2", "git.ref=1.2", "1.2"])
def test_version_comparison_with_list_fails(version): def test_version_comparison_with_list_fails(version):
vlist = VersionList(["=1.3"]) vlist = VersionList(["=1.3"])

View File

@ -695,26 +695,35 @@ def satisfies(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionLis
def overlaps(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionList"]) -> bool: def overlaps(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionList"]) -> bool:
return self.intersects(other) return self.intersects(other)
def union(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionList"]): def _union_if_not_disjoint(
self, other: Union["ClosedOpenRange", ConcreteVersion]
) -> Optional["ClosedOpenRange"]:
"""Same as union, but returns None when the union is not connected. This function is not
implemented for version lists as right-hand side, as that makes little sense."""
if isinstance(other, StandardVersion): if isinstance(other, StandardVersion):
return self if self.lo <= other < self.hi else VersionList([self, other]) return self if self.lo <= other < self.hi else None
if isinstance(other, GitVersion): if isinstance(other, GitVersion):
return self if self.lo <= other.ref_version < self.hi else VersionList([self, other]) return self if self.lo <= other.ref_version < self.hi else None
if isinstance(other, ClosedOpenRange): if isinstance(other, ClosedOpenRange):
# Notice <= cause we want union(1:2, 3:4) = 1:4. # Notice <= cause we want union(1:2, 3:4) = 1:4.
if self.lo <= other.hi and other.lo <= self.hi: return (
return ClosedOpenRange(min(self.lo, other.lo), max(self.hi, other.hi)) ClosedOpenRange(min(self.lo, other.lo), max(self.hi, other.hi))
if self.lo <= other.hi and other.lo <= self.hi
else None
)
return VersionList([self, other]) raise TypeError(f"Unexpected type {type(other)}")
def union(self, other: Union["ClosedOpenRange", ConcreteVersion, "VersionList"]):
if isinstance(other, VersionList): if isinstance(other, VersionList):
v = other.copy() v = other.copy()
v.add(self) v.add(self)
return v return v
raise ValueError(f"Unexpected type {type(other)}") result = self._union_if_not_disjoint(other)
return result if result is not None else VersionList([self, other])
def intersection(self, other: Union["ClosedOpenRange", ConcreteVersion]): def intersection(self, other: Union["ClosedOpenRange", ConcreteVersion]):
# range - version -> singleton or nothing. # range - version -> singleton or nothing.
@ -732,8 +741,9 @@ class VersionList:
def __init__(self, vlist=None): def __init__(self, vlist=None):
self.versions: List[StandardVersion, GitVersion, ClosedOpenRange] = [] self.versions: List[StandardVersion, GitVersion, ClosedOpenRange] = []
if vlist is not None: if vlist is None:
if isinstance(vlist, str): pass
elif isinstance(vlist, str):
vlist = from_string(vlist) vlist = from_string(vlist)
if isinstance(vlist, VersionList): if isinstance(vlist, VersionList):
self.versions = vlist.versions self.versions = vlist.versions
@ -743,8 +753,8 @@ def __init__(self, vlist=None):
for v in vlist: for v in vlist:
self.add(ver(v)) self.add(ver(v))
def add(self, item): def add(self, item: Union[StandardVersion, GitVersion, ClosedOpenRange, "VersionList"]):
if isinstance(item, ConcreteVersion): if isinstance(item, (StandardVersion, GitVersion)):
i = bisect_left(self, item) i = bisect_left(self, item)
# Only insert when prev and next are not intersected. # Only insert when prev and next are not intersected.
if (i == 0 or not item.intersects(self[i - 1])) and ( if (i == 0 or not item.intersects(self[i - 1])) and (
@ -755,16 +765,22 @@ def add(self, item):
elif isinstance(item, ClosedOpenRange): elif isinstance(item, ClosedOpenRange):
i = bisect_left(self, item) i = bisect_left(self, item)
# Note: can span multiple concrete versions to the left, # Note: can span multiple concrete versions to the left (as well as to the right).
# For instance insert 1.2: into [1.2, hash=1.2, 1.3] # For instance insert 1.2: into [1.2, hash=1.2, 1.3, 1.4:1.5]
# would bisect to i = 1. # would bisect at i = 1 and merge i = 0 too.
while i > 0 and item.intersects(self[i - 1]): while i > 0:
item = item.union(self[i - 1]) union = item._union_if_not_disjoint(self[i - 1])
if union is None: # disjoint
break
item = union
del self.versions[i - 1] del self.versions[i - 1]
i -= 1 i -= 1
while i < len(self) and item.intersects(self[i]): while i < len(self):
item = item.union(self[i]) union = item._union_if_not_disjoint(self[i])
if union is None:
break
item = union
del self.versions[i] del self.versions[i]
self.versions.insert(i, item) self.versions.insert(i, item)