spec.py: more virtuals=... type hints (#49753)

Deal with the "issue" that passing a str instance does not cause a
type check failure, because str is a subset of Sequence[str] and
Iterable[str]. Instead fix it by special casing the str instance.
This commit is contained in:
Harmen Stoppels 2025-04-02 09:05:00 +02:00 committed by GitHub
parent ca64050f6a
commit 0facab231f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 6 deletions

View File

@ -3834,7 +3834,7 @@ def virtual_on_edge(self, parent_node, provider_node, virtual):
provider_spec = self._specs[provider_node]
dependencies = [x for x in dependencies if id(x.spec) == id(provider_spec)]
assert len(dependencies) == 1, f"{virtual}: {provider_node.pkg}"
dependencies[0].update_virtuals((virtual,))
dependencies[0].update_virtuals(virtual)
def reorder_flags(self):
"""For each spec, determine the order of compiler flags applied to it.

View File

@ -754,11 +754,17 @@ def update_deptypes(self, depflag: dt.DepFlag) -> bool:
self.depflag = new
return True
def update_virtuals(self, virtuals: Iterable[str]) -> bool:
def update_virtuals(self, virtuals: Union[str, Iterable[str]]) -> bool:
"""Update the list of provided virtuals"""
old = self.virtuals
self.virtuals = tuple(sorted(set(virtuals).union(self.virtuals)))
return old != self.virtuals
if isinstance(virtuals, str):
union = {virtuals, *self.virtuals}
else:
union = {*virtuals, *self.virtuals}
if len(union) == len(old):
return False
self.virtuals = tuple(sorted(union))
return True
def copy(self) -> "DependencySpec":
"""Return a copy of this edge"""
@ -1041,7 +1047,7 @@ def select(
parent: name of the parent package
child: name of the child package
depflag: allowed dependency types in flag form
virtuals: list of virtuals on the edge
virtuals: list of virtuals or specific virtual on the edge
"""
if not depflag:
return []
@ -1590,7 +1596,11 @@ def _get_dependency(self, name):
return deps[0]
def edges_from_dependents(
self, name=None, depflag: dt.DepFlag = dt.ALL, *, virtuals: Optional[List[str]] = None
self,
name=None,
depflag: dt.DepFlag = dt.ALL,
*,
virtuals: Optional[Union[str, Sequence[str]]] = None,
) -> List[DependencySpec]:
"""Return a list of edges connecting this node in the DAG
to parents.

View File

@ -1948,6 +1948,18 @@ def test_edge_equality_does_not_depend_on_virtual_order():
assert tuple(sorted(edge2.virtuals)) == edge1.virtuals
def test_update_virtuals():
parent, child = Spec("parent"), Spec("child")
edge = DependencySpec(parent, child, depflag=0, virtuals=("mpi", "lapack"))
assert edge.update_virtuals("blas")
assert edge.virtuals == ("blas", "lapack", "mpi")
assert edge.update_virtuals(("c", "fortran", "mpi", "lapack"))
assert edge.virtuals == ("blas", "c", "fortran", "lapack", "mpi")
assert not edge.update_virtuals("mpi")
assert not edge.update_virtuals(("c", "fortran", "mpi", "lapack"))
assert edge.virtuals == ("blas", "c", "fortran", "lapack", "mpi")
def test_virtual_queries_work_for_strings_and_lists():
"""Ensure that ``dependencies()`` works with both virtuals=str and virtuals=[str, ...]."""
parent, child = Spec("parent"), Spec("child")