refactor: Index provided virtuals by when spec

Part 4 of reworking all package metadata to key by `when` conditions.

Changes conflict dictionary structure from this:

    { provided_spec: {when_spec, ...} }

to this:

    { when_spec: {provided_spec, ...} }
This commit is contained in:
Todd Gamblin 2023-06-21 00:20:32 -07:00
parent 7994caaeda
commit 6753cc0b81
7 changed files with 64 additions and 51 deletions

View File

@ -702,15 +702,13 @@ def _unknown_variants_in_directives(pkgs, error_cls):
)
)
# Check "patch" directive
for _, triggers in pkg_cls.provided.items():
triggers = [spack.spec.Spec(x) for x in triggers]
for vrn in triggers:
errors.extend(
_analyze_variants_in_directive(
pkg_cls, vrn, directive="patch", error_cls=error_cls
)
# Check "provides" directive
for when_spec in pkg_cls.provided:
errors.extend(
_analyze_variants_in_directive(
pkg_cls, when_spec, directive="provides", error_cls=error_cls
)
)
# Check "resource" directive
for vrn in pkg_cls.resources:
@ -752,6 +750,18 @@ def _issues_in_depends_on_directive(pkgs, error_cls):
]
errors.append(error_cls(summary=summary, details=details))
def check_virtual_with_variants(spec, msg):
if not spec.virtual or not spec.variants:
return
error = error_cls(
f"{pkg_name}: {msg}",
f"remove variants from '{spec}' in depends_on directive in {filename}",
)
errors.append(error)
check_virtual_with_variants(dep.spec, "virtual dependency cannot have variants")
check_virtual_with_variants(dep.spec, "virtual when= spec cannot have variants")
# No need to analyze virtual packages
if spack.repo.PATH.is_virtual(dep_name):
continue
@ -963,9 +973,11 @@ def _extracts_errors(triggers, summary):
summary = f"{pkg_name}: wrong 'when=' condition for the '{vname}' variant"
errors.extend(_extracts_errors(triggers, summary))
for provided, triggers in pkg_cls.provided.items():
summary = f"{pkg_name}: wrong 'when=' condition for the '{provided}' virtual"
errors.extend(_extracts_errors(triggers, summary))
for when, providers, details in _error_items(pkg_cls.provided):
errors.extend(
error_cls(f"{pkg_name}: wrong 'when=' condition for '{provided}' virtual", details)
for provided in providers
)
for when, requirements, details in _error_items(pkg_cls.requirements):
errors.append(

View File

@ -474,13 +474,7 @@ def print_virtuals(pkg, args):
color.cprint("")
color.cprint(section_title("Virtual Packages: "))
if pkg.provided:
inverse_map = {}
for spec, whens in pkg.provided.items():
for when in whens:
if when not in inverse_map:
inverse_map[when] = set()
inverse_map[when].add(spec)
for when, specs in reversed(sorted(inverse_map.items())):
for when, specs in reversed(sorted(pkg.provided.items())):
line = " %s provides %s" % (
when.colorized(),
", ".join(s.colorized() for s in specs),

View File

@ -613,7 +613,7 @@ def _execute_extends(pkg):
@directive(dicts=("provided", "provided_together"))
def provides(*specs, when: Optional[str] = None):
def provides(*specs: SpecType, when: WhenType = None):
"""Allows packages to provide a virtual dependency.
If a package provides "mpi", other packages can declare that they depend on "mpi",
@ -624,7 +624,7 @@ def provides(*specs, when: Optional[str] = None):
when: condition when this provides clause needs to be considered
"""
def _execute_provides(pkg):
def _execute_provides(pkg: "spack.package_base.PackageBase"):
import spack.parser # Avoid circular dependency
when_spec = _make_when_spec(when)
@ -634,6 +634,7 @@ def _execute_provides(pkg):
# ``when`` specs for ``provides()`` need a name, as they are used
# to build the ProviderIndex.
when_spec.name = pkg.name
spec_objs = [spack.spec.Spec(x) for x in specs]
spec_names = [x.name for x in spec_objs]
if len(spec_names) > 1:
@ -643,9 +644,8 @@ def _execute_provides(pkg):
if pkg.name == provided_spec.name:
raise CircularReferenceError("Package '%s' cannot provide itself." % pkg.name)
if provided_spec not in pkg.provided:
pkg.provided[provided_spec] = set()
pkg.provided[provided_spec].add(when_spec)
provided_set = pkg.provided.setdefault(when_spec, set())
provided_set.add(provided_spec)
return _execute_provides

View File

@ -25,7 +25,7 @@
import time
import traceback
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
import llnl.util.filesystem as fsys
import llnl.util.tty as tty
@ -565,6 +565,8 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
requirements: Dict[
"spack.spec.Spec", List[Tuple[Tuple["spack.spec.Spec", ...], str, Optional[str]]]
]
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
provided_together: Dict["spack.spec.Spec", List[Set[str]]]
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
#: By default, packages are not virtual
@ -1342,9 +1344,9 @@ def provides(self, vpkg_name):
True if this package provides a virtual package with the specified name
"""
return any(
any(self.spec.intersects(c) for c in constraints)
for s, constraints in self.provided.items()
if s.name == vpkg_name
any(spec.name == vpkg_name for spec in provided)
for when_spec, provided in self.provided.items()
if self.spec.intersects(when_spec)
)
@property
@ -1354,10 +1356,16 @@ def virtuals_provided(self):
"""
return [
vspec
for vspec, constraints in self.provided.items()
if any(self.spec.satisfies(c) for c in constraints)
for when_spec, provided in self.provided.items()
for vspec in provided
if self.spec.satisfies(when_spec)
]
@classmethod
def provided_virtual_names(cls):
"""Return sorted list of names of virtuals that can be provided by this package."""
return sorted(set(vpkg.name for virtuals in cls.provided.values() for vpkg in virtuals))
@property
def prefix(self):
"""Get the prefix into which this package should be installed."""

View File

@ -128,8 +128,8 @@ def update(self, spec):
assert not self.repository.is_virtual_safe(spec.name), msg
pkg_provided = self.repository.get_pkg_class(spec.name).provided
for provided_spec, provider_specs in pkg_provided.items():
for provider_spec_readonly in provider_specs:
for provider_spec_readonly, provided_specs in pkg_provided.items():
for provided_spec in provided_specs:
# TODO: fix this comment.
# We want satisfaction other than flags
provider_spec = provider_spec_readonly.copy()

View File

@ -1628,19 +1628,20 @@ def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
self.gen.fact(fn.imposed_constraint(condition_id, *pred.args))
def package_provider_rules(self, pkg):
for provider_name in sorted(set(s.name for s in pkg.provided.keys())):
if provider_name not in self.possible_virtuals:
for vpkg_name in pkg.provided_virtual_names():
if vpkg_name not in self.possible_virtuals:
continue
self.gen.fact(fn.pkg_fact(pkg.name, fn.possible_provider(provider_name)))
self.gen.fact(fn.pkg_fact(pkg.name, fn.possible_provider(vpkg_name)))
for provided, whens in pkg.provided.items():
if provided.name not in self.possible_virtuals:
continue
for when in whens:
msg = "%s provides %s when %s" % (pkg.name, provided, when)
condition_id = self.condition(when, provided, pkg.name, msg)
for when, provided in pkg.provided.items():
for vpkg in provided:
if vpkg.name not in self.possible_virtuals:
continue
msg = f"{pkg.name} provides {vpkg} when {when}"
condition_id = self.condition(when, vpkg, pkg.name, msg)
self.gen.fact(
fn.pkg_fact(when.name, fn.provider_condition(condition_id, provided.name))
fn.pkg_fact(when.name, fn.provider_condition(condition_id, vpkg.name))
)
self.gen.newline()
@ -3383,7 +3384,7 @@ def _is_reusable(spec: spack.spec.Spec, packages, local: bool) -> bool:
return True
try:
provided = [p.name for p in spec.package.provided]
provided = spack.repo.PATH.get(spec).provided_virtual_names()
except spack.repo.RepoError:
provided = []

View File

@ -2788,7 +2788,7 @@ def _old_concretize(self, tests=False, deprecation_warning=True):
for dep in self.traverse():
visited_user_specs.add(dep.name)
pkg_cls = spack.repo.PATH.get_pkg_class(dep.name)
visited_user_specs.update(x.name for x in pkg_cls(dep).provided)
visited_user_specs.update(pkg_cls(dep).provided_virtual_names())
extra = set(user_spec_deps.keys()).difference(visited_user_specs)
if extra:
@ -3774,11 +3774,9 @@ def intersects(self, other: Union[str, "Spec"], deps: bool = True) -> bool:
return False
if pkg.provides(virtual_spec.name):
for provided, when_specs in pkg.provided.items():
if any(
non_virtual_spec.intersects(when, deps=False) for when in when_specs
):
if provided.intersects(virtual_spec):
for when_spec, provided in pkg.provided.items():
if non_virtual_spec.intersects(when_spec, deps=False):
if any(vpkg.intersects(virtual_spec) for vpkg in provided):
return True
return False
@ -3881,9 +3879,9 @@ def satisfies(self, other: Union[str, "Spec"], deps: bool = True) -> bool:
return False
if pkg.provides(other.name):
for provided, when_specs in pkg.provided.items():
if any(self.satisfies(when, deps=False) for when in when_specs):
if provided.intersects(other):
for when_spec, provided in pkg.provided.items():
if self.satisfies(when_spec, deps=False):
if any(vpkg.intersects(other) for vpkg in provided):
return True
return False