diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py index 8214f531d9e..66a2258983a 100644 --- a/lib/spack/spack/solver/asp.py +++ b/lib/spack/spack/solver/asp.py @@ -3834,7 +3834,7 @@ def virtual_on_edge(self, parent_node, provider_node, virtual): provider_spec = self._specs[provider_node] dependencies = [x for x in dependencies if id(x.spec) == id(provider_spec)] assert len(dependencies) == 1, f"{virtual}: {provider_node.pkg}" - dependencies[0].update_virtuals((virtual,)) + dependencies[0].update_virtuals(virtual) def reorder_flags(self): """For each spec, determine the order of compiler flags applied to it. diff --git a/lib/spack/spack/spec.py b/lib/spack/spack/spec.py index eb546abbc7d..17fa574e07e 100644 --- a/lib/spack/spack/spec.py +++ b/lib/spack/spack/spec.py @@ -754,11 +754,17 @@ def update_deptypes(self, depflag: dt.DepFlag) -> bool: self.depflag = new return True - def update_virtuals(self, virtuals: Iterable[str]) -> bool: + def update_virtuals(self, virtuals: Union[str, Iterable[str]]) -> bool: """Update the list of provided virtuals""" old = self.virtuals - self.virtuals = tuple(sorted(set(virtuals).union(self.virtuals))) - return old != self.virtuals + if isinstance(virtuals, str): + union = {virtuals, *self.virtuals} + else: + union = {*virtuals, *self.virtuals} + if len(union) == len(old): + return False + self.virtuals = tuple(sorted(union)) + return True def copy(self) -> "DependencySpec": """Return a copy of this edge""" @@ -1041,7 +1047,7 @@ def select( parent: name of the parent package child: name of the child package depflag: allowed dependency types in flag form - virtuals: list of virtuals on the edge + virtuals: list of virtuals or specific virtual on the edge """ if not depflag: return [] @@ -1590,7 +1596,11 @@ def _get_dependency(self, name): return deps[0] def edges_from_dependents( - self, name=None, depflag: dt.DepFlag = dt.ALL, *, virtuals: Optional[List[str]] = None + self, + name=None, + depflag: dt.DepFlag = dt.ALL, + *, + virtuals: Optional[Union[str, Sequence[str]]] = None, ) -> List[DependencySpec]: """Return a list of edges connecting this node in the DAG to parents. diff --git a/lib/spack/spack/test/spec_semantics.py b/lib/spack/spack/test/spec_semantics.py index ba5691bc58c..eee7791baa4 100644 --- a/lib/spack/spack/test/spec_semantics.py +++ b/lib/spack/spack/test/spec_semantics.py @@ -1948,6 +1948,18 @@ def test_edge_equality_does_not_depend_on_virtual_order(): assert tuple(sorted(edge2.virtuals)) == edge1.virtuals +def test_update_virtuals(): + parent, child = Spec("parent"), Spec("child") + edge = DependencySpec(parent, child, depflag=0, virtuals=("mpi", "lapack")) + assert edge.update_virtuals("blas") + assert edge.virtuals == ("blas", "lapack", "mpi") + assert edge.update_virtuals(("c", "fortran", "mpi", "lapack")) + assert edge.virtuals == ("blas", "c", "fortran", "lapack", "mpi") + assert not edge.update_virtuals("mpi") + assert not edge.update_virtuals(("c", "fortran", "mpi", "lapack")) + assert edge.virtuals == ("blas", "c", "fortran", "lapack", "mpi") + + def test_virtual_queries_work_for_strings_and_lists(): """Ensure that ``dependencies()`` works with both virtuals=str and virtuals=[str, ...].""" parent, child = Spec("parent"), Spec("child")