refactor: Index provided virtuals by when spec

Part 4 of reworking all package metadata to key by `when` conditions.

Changes conflict dictionary structure from this:

    { provided_spec: {when_spec, ...} }

to this:

    { when_spec: {provided_spec, ...} }
This commit is contained in:
Todd Gamblin 2023-06-21 00:20:32 -07:00
parent 7994caaeda
commit 6753cc0b81
7 changed files with 64 additions and 51 deletions

View File

@ -702,15 +702,13 @@ def _unknown_variants_in_directives(pkgs, error_cls):
) )
) )
# Check "patch" directive # Check "provides" directive
for _, triggers in pkg_cls.provided.items(): for when_spec in pkg_cls.provided:
triggers = [spack.spec.Spec(x) for x in triggers] errors.extend(
for vrn in triggers: _analyze_variants_in_directive(
errors.extend( pkg_cls, when_spec, directive="provides", error_cls=error_cls
_analyze_variants_in_directive(
pkg_cls, vrn, directive="patch", error_cls=error_cls
)
) )
)
# Check "resource" directive # Check "resource" directive
for vrn in pkg_cls.resources: for vrn in pkg_cls.resources:
@ -752,6 +750,18 @@ def _issues_in_depends_on_directive(pkgs, error_cls):
] ]
errors.append(error_cls(summary=summary, details=details)) errors.append(error_cls(summary=summary, details=details))
def check_virtual_with_variants(spec, msg):
if not spec.virtual or not spec.variants:
return
error = error_cls(
f"{pkg_name}: {msg}",
f"remove variants from '{spec}' in depends_on directive in {filename}",
)
errors.append(error)
check_virtual_with_variants(dep.spec, "virtual dependency cannot have variants")
check_virtual_with_variants(dep.spec, "virtual when= spec cannot have variants")
# No need to analyze virtual packages # No need to analyze virtual packages
if spack.repo.PATH.is_virtual(dep_name): if spack.repo.PATH.is_virtual(dep_name):
continue continue
@ -963,9 +973,11 @@ def _extracts_errors(triggers, summary):
summary = f"{pkg_name}: wrong 'when=' condition for the '{vname}' variant" summary = f"{pkg_name}: wrong 'when=' condition for the '{vname}' variant"
errors.extend(_extracts_errors(triggers, summary)) errors.extend(_extracts_errors(triggers, summary))
for provided, triggers in pkg_cls.provided.items(): for when, providers, details in _error_items(pkg_cls.provided):
summary = f"{pkg_name}: wrong 'when=' condition for the '{provided}' virtual" errors.extend(
errors.extend(_extracts_errors(triggers, summary)) error_cls(f"{pkg_name}: wrong 'when=' condition for '{provided}' virtual", details)
for provided in providers
)
for when, requirements, details in _error_items(pkg_cls.requirements): for when, requirements, details in _error_items(pkg_cls.requirements):
errors.append( errors.append(

View File

@ -474,13 +474,7 @@ def print_virtuals(pkg, args):
color.cprint("") color.cprint("")
color.cprint(section_title("Virtual Packages: ")) color.cprint(section_title("Virtual Packages: "))
if pkg.provided: if pkg.provided:
inverse_map = {} for when, specs in reversed(sorted(pkg.provided.items())):
for spec, whens in pkg.provided.items():
for when in whens:
if when not in inverse_map:
inverse_map[when] = set()
inverse_map[when].add(spec)
for when, specs in reversed(sorted(inverse_map.items())):
line = " %s provides %s" % ( line = " %s provides %s" % (
when.colorized(), when.colorized(),
", ".join(s.colorized() for s in specs), ", ".join(s.colorized() for s in specs),

View File

@ -613,7 +613,7 @@ def _execute_extends(pkg):
@directive(dicts=("provided", "provided_together")) @directive(dicts=("provided", "provided_together"))
def provides(*specs, when: Optional[str] = None): def provides(*specs: SpecType, when: WhenType = None):
"""Allows packages to provide a virtual dependency. """Allows packages to provide a virtual dependency.
If a package provides "mpi", other packages can declare that they depend on "mpi", If a package provides "mpi", other packages can declare that they depend on "mpi",
@ -624,7 +624,7 @@ def provides(*specs, when: Optional[str] = None):
when: condition when this provides clause needs to be considered when: condition when this provides clause needs to be considered
""" """
def _execute_provides(pkg): def _execute_provides(pkg: "spack.package_base.PackageBase"):
import spack.parser # Avoid circular dependency import spack.parser # Avoid circular dependency
when_spec = _make_when_spec(when) when_spec = _make_when_spec(when)
@ -634,6 +634,7 @@ def _execute_provides(pkg):
# ``when`` specs for ``provides()`` need a name, as they are used # ``when`` specs for ``provides()`` need a name, as they are used
# to build the ProviderIndex. # to build the ProviderIndex.
when_spec.name = pkg.name when_spec.name = pkg.name
spec_objs = [spack.spec.Spec(x) for x in specs] spec_objs = [spack.spec.Spec(x) for x in specs]
spec_names = [x.name for x in spec_objs] spec_names = [x.name for x in spec_objs]
if len(spec_names) > 1: if len(spec_names) > 1:
@ -643,9 +644,8 @@ def _execute_provides(pkg):
if pkg.name == provided_spec.name: if pkg.name == provided_spec.name:
raise CircularReferenceError("Package '%s' cannot provide itself." % pkg.name) raise CircularReferenceError("Package '%s' cannot provide itself." % pkg.name)
if provided_spec not in pkg.provided: provided_set = pkg.provided.setdefault(when_spec, set())
pkg.provided[provided_spec] = set() provided_set.add(provided_spec)
pkg.provided[provided_spec].add(when_spec)
return _execute_provides return _execute_provides

View File

@ -25,7 +25,7 @@
import time import time
import traceback import traceback
import warnings import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
import llnl.util.filesystem as fsys import llnl.util.filesystem as fsys
import llnl.util.tty as tty import llnl.util.tty as tty
@ -565,6 +565,8 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
requirements: Dict[ requirements: Dict[
"spack.spec.Spec", List[Tuple[Tuple["spack.spec.Spec", ...], str, Optional[str]]] "spack.spec.Spec", List[Tuple[Tuple["spack.spec.Spec", ...], str, Optional[str]]]
] ]
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
provided_together: Dict["spack.spec.Spec", List[Set[str]]]
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]] patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
#: By default, packages are not virtual #: By default, packages are not virtual
@ -1342,9 +1344,9 @@ def provides(self, vpkg_name):
True if this package provides a virtual package with the specified name True if this package provides a virtual package with the specified name
""" """
return any( return any(
any(self.spec.intersects(c) for c in constraints) any(spec.name == vpkg_name for spec in provided)
for s, constraints in self.provided.items() for when_spec, provided in self.provided.items()
if s.name == vpkg_name if self.spec.intersects(when_spec)
) )
@property @property
@ -1354,10 +1356,16 @@ def virtuals_provided(self):
""" """
return [ return [
vspec vspec
for vspec, constraints in self.provided.items() for when_spec, provided in self.provided.items()
if any(self.spec.satisfies(c) for c in constraints) for vspec in provided
if self.spec.satisfies(when_spec)
] ]
@classmethod
def provided_virtual_names(cls):
"""Return sorted list of names of virtuals that can be provided by this package."""
return sorted(set(vpkg.name for virtuals in cls.provided.values() for vpkg in virtuals))
@property @property
def prefix(self): def prefix(self):
"""Get the prefix into which this package should be installed.""" """Get the prefix into which this package should be installed."""

View File

@ -128,8 +128,8 @@ def update(self, spec):
assert not self.repository.is_virtual_safe(spec.name), msg assert not self.repository.is_virtual_safe(spec.name), msg
pkg_provided = self.repository.get_pkg_class(spec.name).provided pkg_provided = self.repository.get_pkg_class(spec.name).provided
for provided_spec, provider_specs in pkg_provided.items(): for provider_spec_readonly, provided_specs in pkg_provided.items():
for provider_spec_readonly in provider_specs: for provided_spec in provided_specs:
# TODO: fix this comment. # TODO: fix this comment.
# We want satisfaction other than flags # We want satisfaction other than flags
provider_spec = provider_spec_readonly.copy() provider_spec = provider_spec_readonly.copy()

View File

@ -1628,19 +1628,20 @@ def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
self.gen.fact(fn.imposed_constraint(condition_id, *pred.args)) self.gen.fact(fn.imposed_constraint(condition_id, *pred.args))
def package_provider_rules(self, pkg): def package_provider_rules(self, pkg):
for provider_name in sorted(set(s.name for s in pkg.provided.keys())): for vpkg_name in pkg.provided_virtual_names():
if provider_name not in self.possible_virtuals: if vpkg_name not in self.possible_virtuals:
continue continue
self.gen.fact(fn.pkg_fact(pkg.name, fn.possible_provider(provider_name))) self.gen.fact(fn.pkg_fact(pkg.name, fn.possible_provider(vpkg_name)))
for provided, whens in pkg.provided.items(): for when, provided in pkg.provided.items():
if provided.name not in self.possible_virtuals: for vpkg in provided:
continue if vpkg.name not in self.possible_virtuals:
for when in whens: continue
msg = "%s provides %s when %s" % (pkg.name, provided, when)
condition_id = self.condition(when, provided, pkg.name, msg) msg = f"{pkg.name} provides {vpkg} when {when}"
condition_id = self.condition(when, vpkg, pkg.name, msg)
self.gen.fact( self.gen.fact(
fn.pkg_fact(when.name, fn.provider_condition(condition_id, provided.name)) fn.pkg_fact(when.name, fn.provider_condition(condition_id, vpkg.name))
) )
self.gen.newline() self.gen.newline()
@ -3383,7 +3384,7 @@ def _is_reusable(spec: spack.spec.Spec, packages, local: bool) -> bool:
return True return True
try: try:
provided = [p.name for p in spec.package.provided] provided = spack.repo.PATH.get(spec).provided_virtual_names()
except spack.repo.RepoError: except spack.repo.RepoError:
provided = [] provided = []

View File

@ -2788,7 +2788,7 @@ def _old_concretize(self, tests=False, deprecation_warning=True):
for dep in self.traverse(): for dep in self.traverse():
visited_user_specs.add(dep.name) visited_user_specs.add(dep.name)
pkg_cls = spack.repo.PATH.get_pkg_class(dep.name) pkg_cls = spack.repo.PATH.get_pkg_class(dep.name)
visited_user_specs.update(x.name for x in pkg_cls(dep).provided) visited_user_specs.update(pkg_cls(dep).provided_virtual_names())
extra = set(user_spec_deps.keys()).difference(visited_user_specs) extra = set(user_spec_deps.keys()).difference(visited_user_specs)
if extra: if extra:
@ -3774,11 +3774,9 @@ def intersects(self, other: Union[str, "Spec"], deps: bool = True) -> bool:
return False return False
if pkg.provides(virtual_spec.name): if pkg.provides(virtual_spec.name):
for provided, when_specs in pkg.provided.items(): for when_spec, provided in pkg.provided.items():
if any( if non_virtual_spec.intersects(when_spec, deps=False):
non_virtual_spec.intersects(when, deps=False) for when in when_specs if any(vpkg.intersects(virtual_spec) for vpkg in provided):
):
if provided.intersects(virtual_spec):
return True return True
return False return False
@ -3881,9 +3879,9 @@ def satisfies(self, other: Union[str, "Spec"], deps: bool = True) -> bool:
return False return False
if pkg.provides(other.name): if pkg.provides(other.name):
for provided, when_specs in pkg.provided.items(): for when_spec, provided in pkg.provided.items():
if any(self.satisfies(when, deps=False) for when in when_specs): if self.satisfies(when_spec, deps=False):
if provided.intersects(other): if any(vpkg.intersects(other) for vpkg in provided):
return True return True
return False return False