refactor: add type annotations and refactor solver conditions (#42081)

Refactoring `SpackSolverSetup` is a bit easier with type annotations, so I started
adding some. This adds annotations for the (many) instance variables on
`SpackSolverSetup` as well as a few other places.

This also refactors `condition()` to reduce redundancy and to allow
`_get_condition_id()` to be called independently of the larger condition
function.


Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
This commit is contained in:
Todd Gamblin 2024-02-26 14:26:01 -08:00 committed by GitHub
parent c7df258ca6
commit 48088ee24a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 91 deletions

View File

@ -566,6 +566,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]] provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
provided_together: Dict["spack.spec.Spec", List[Set[str]]] provided_together: Dict["spack.spec.Spec", List[Set[str]]]
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]] patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
variants: Dict[str, Tuple["spack.variant.Variant", "spack.spec.Spec"]]
#: By default, packages are not virtual #: By default, packages are not virtual
#: Virtual packages override this attribute #: Virtual packages override this attribute

View File

@ -15,7 +15,7 @@
import types import types
import typing import typing
import warnings import warnings
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Type, Union
import archspec.cpu import archspec.cpu
@ -258,7 +258,7 @@ def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunc
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts)) return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
def _create_counter(specs, tests): def _create_counter(specs: List[spack.spec.Spec], tests: bool):
strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none")
if strategy == "full": if strategy == "full":
return FullDuplicatesCounter(specs, tests=tests) return FullDuplicatesCounter(specs, tests=tests)
@ -897,35 +897,41 @@ def __iter__(self):
return iter(self.data) return iter(self.data)
# types for condition caching in solver setup
ConditionSpecKey = Tuple[str, Optional[TransformFunction]]
ConditionIdFunctionPair = Tuple[int, List[AspFunction]]
ConditionSpecCache = Dict[str, Dict[ConditionSpecKey, ConditionIdFunctionPair]]
class SpackSolverSetup: class SpackSolverSetup:
"""Class to set up and run a Spack concretization solve.""" """Class to set up and run a Spack concretization solve."""
def __init__(self, tests=False): def __init__(self, tests: bool = False):
self.gen = None # set by setup() # these are all initialized in setup()
self.gen: "ProblemInstanceBuilder" = ProblemInstanceBuilder()
self.possible_virtuals: Set[str] = set()
self.assumptions = [] self.assumptions: List[Tuple["clingo.Symbol", bool]] = [] # type: ignore[name-defined]
self.declared_versions = collections.defaultdict(list) self.declared_versions: Dict[str, List[DeclaredVersion]] = collections.defaultdict(list)
self.possible_versions = collections.defaultdict(set) self.possible_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(set)
self.deprecated_versions = collections.defaultdict(set) self.deprecated_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(
set
)
self.possible_virtuals = None self.possible_compilers: List = []
self.possible_compilers = [] self.possible_oses: Set = set()
self.possible_oses = set() self.variant_values_from_specs: Set = set()
self.variant_values_from_specs = set() self.version_constraints: Set = set()
self.version_constraints = set() self.target_constraints: Set = set()
self.target_constraints = set() self.default_targets: List = []
self.default_targets = [] self.compiler_version_constraints: Set = set()
self.compiler_version_constraints = set() self.post_facts: List = []
self.post_facts = []
# (ID, CompilerSpec) -> dictionary of attributes self.reusable_and_possible: ConcreteSpecsByHash = ConcreteSpecsByHash()
self.compiler_info = collections.defaultdict(dict)
self.reusable_and_possible = ConcreteSpecsByHash() self._id_counter: Iterator[int] = itertools.count()
self._trigger_cache: ConditionSpecCache = collections.defaultdict(dict)
self._id_counter = itertools.count() self._effect_cache: ConditionSpecCache = collections.defaultdict(dict)
self._trigger_cache = collections.defaultdict(dict)
self._effect_cache = collections.defaultdict(dict)
# Caches to optimize the setup phase of the solver # Caches to optimize the setup phase of the solver
self.target_specs_cache = None self.target_specs_cache = None
@ -937,8 +943,8 @@ def __init__(self, tests=False):
self.concretize_everything = True self.concretize_everything = True
# Set during the call to setup # Set during the call to setup
self.pkgs = None self.pkgs: Set[str] = set()
self.explicitly_required_namespaces = {} self.explicitly_required_namespaces: Dict[str, str] = {}
def pkg_version_rules(self, pkg): def pkg_version_rules(self, pkg):
"""Output declared versions of a package. """Output declared versions of a package.
@ -1222,6 +1228,38 @@ def variant_rules(self, pkg):
self.gen.newline() self.gen.newline()
def _get_condition_id(
self,
named_cond: spack.spec.Spec,
cache: ConditionSpecCache,
body: bool,
transform: Optional[TransformFunction] = None,
) -> int:
"""Get the id for one half of a condition (either a trigger or an imposed constraint).
Construct a key from the condition spec and any associated transformation, and
cache the ASP functions that they imply. The saved functions will be output
later in ``trigger_rules()`` and ``effect_rules()``.
Returns:
The id of the cached trigger or effect.
"""
pkg_cache = cache[named_cond.name]
named_cond_key = (str(named_cond), transform)
result = pkg_cache.get(named_cond_key)
if result:
return result[0]
cond_id = next(self._id_counter)
requirements = self.spec_clauses(named_cond, body=body)
if transform:
requirements = transform(named_cond, requirements)
pkg_cache[named_cond_key] = (cond_id, requirements)
return cond_id
def condition( def condition(
self, self,
required_spec: spack.spec.Spec, required_spec: spack.spec.Spec,
@ -1247,7 +1285,8 @@ def condition(
""" """
named_cond = required_spec.copy() named_cond = required_spec.copy()
named_cond.name = named_cond.name or name named_cond.name = named_cond.name or name
assert named_cond.name, "must provide name for anonymous conditions!" if not named_cond.name:
raise ValueError(f"Must provide a name for anonymous condition: '{named_cond}'")
# Check if we can emit the requirements before updating the condition ID counter. # Check if we can emit the requirements before updating the condition ID counter.
# In this way, if a condition can't be emitted but the exception is handled in the caller, # In this way, if a condition can't be emitted but the exception is handled in the caller,
@ -1257,35 +1296,19 @@ def condition(
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg)) self.gen.fact(fn.condition_reason(condition_id, msg))
cache = self._trigger_cache[named_cond.name] trigger_id = self._get_condition_id(
named_cond, cache=self._trigger_cache, body=True, transform=transform_required
named_cond_key = (str(named_cond), transform_required) )
if named_cond_key not in cache:
trigger_id = next(self._id_counter)
requirements = self.spec_clauses(named_cond, body=True, required_from=name)
if transform_required:
requirements = transform_required(named_cond, requirements)
cache[named_cond_key] = (trigger_id, requirements)
trigger_id, requirements = cache[named_cond_key]
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
if not imposed_spec: if not imposed_spec:
return condition_id return condition_id
cache = self._effect_cache[named_cond.name] effect_id = self._get_condition_id(
imposed_spec_key = (str(imposed_spec), transform_imposed) imposed_spec, cache=self._effect_cache, body=False, transform=transform_imposed
if imposed_spec_key not in cache: )
effect_id = next(self._id_counter)
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
if transform_imposed:
requirements = transform_imposed(imposed_spec, requirements)
cache[imposed_spec_key] = (effect_id, requirements)
effect_id, requirements = cache[imposed_spec_key]
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
return condition_id return condition_id
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False): def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
@ -1387,23 +1410,13 @@ def virtual_preferences(self, pkg_name, func):
def provider_defaults(self): def provider_defaults(self):
self.gen.h2("Default virtual providers") self.gen.h2("Default virtual providers")
msg = (
"Internal Error: possible_virtuals is not populated. Please report to the spack"
" maintainers"
)
assert self.possible_virtuals is not None, msg
self.virtual_preferences( self.virtual_preferences(
"all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i)) "all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i))
) )
def provider_requirements(self): def provider_requirements(self):
self.gen.h2("Requirements on virtual providers") self.gen.h2("Requirements on virtual providers")
msg = (
"Internal Error: possible_virtuals is not populated. Please report to the spack"
" maintainers"
)
parser = RequirementParser(spack.config.CONFIG) parser = RequirementParser(spack.config.CONFIG)
assert self.possible_virtuals is not None, msg
for virtual_str in sorted(self.possible_virtuals): for virtual_str in sorted(self.possible_virtuals):
rules = parser.rules_from_virtual(virtual_str) rules = parser.rules_from_virtual(virtual_str)
if rules: if rules:
@ -1602,35 +1615,57 @@ def flag_defaults(self):
fn.compiler_version_flag(compiler.name, compiler.version, name, flag) fn.compiler_version_flag(compiler.name, compiler.version, name, flag)
) )
def spec_clauses(self, *args, **kwargs): def spec_clauses(
"""Wrap a call to `_spec_clauses()` into a try/except block that self,
raises a comprehensible error message in case of failure. spec: spack.spec.Spec,
*,
body: bool = False,
transitive: bool = True,
expand_hashes: bool = False,
concrete_build_deps=False,
required_from: Optional[str] = None,
) -> List[AspFunction]:
"""Wrap a call to `_spec_clauses()` into a try/except block with better error handling.
Arguments are as for ``_spec_clauses()`` except ``required_from``.
Arguments:
required_from: name of package that caused this call.
""" """
requestor = kwargs.pop("required_from", None)
try: try:
clauses = self._spec_clauses(*args, **kwargs) clauses = self._spec_clauses(
spec,
body=body,
transitive=transitive,
expand_hashes=expand_hashes,
concrete_build_deps=concrete_build_deps,
)
except RuntimeError as exc: except RuntimeError as exc:
msg = str(exc) msg = str(exc)
if requestor: if required_from:
msg += ' [required from package "{0}"]'.format(requestor) msg += f" [required from package '{required_from}']"
raise RuntimeError(msg) raise RuntimeError(msg)
return clauses return clauses
def _spec_clauses( def _spec_clauses(
self, spec, body=False, transitive=True, expand_hashes=False, concrete_build_deps=False self,
): spec: spack.spec.Spec,
*,
body: bool = False,
transitive: bool = True,
expand_hashes: bool = False,
concrete_build_deps: bool = False,
) -> List[AspFunction]:
"""Return a list of clauses for a spec mandates are true. """Return a list of clauses for a spec mandates are true.
Arguments: Arguments:
spec (spack.spec.Spec): the spec to analyze spec: the spec to analyze
body (bool): if True, generate clauses to be used in rule bodies body: if True, generate clauses to be used in rule bodies (final values) instead
(final values) instead of rule heads (setters). of rule heads (setters).
transitive (bool): if False, don't generate clauses from transitive: if False, don't generate clauses from dependencies (default True)
dependencies (default True) expand_hashes: if True, descend into hashes of concrete specs (default False)
expand_hashes (bool): if True, descend into hashes of concrete specs concrete_build_deps: if False, do not include pure build deps of concrete specs
(default False) (as they have no effect on runtime constraints)
concrete_build_deps (bool): if False, do not include pure build deps
of concrete specs (as they have no effect on runtime constraints)
Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates
hashes for the dependency requirements of concrete specs. If ``expand_hashes`` hashes for the dependency requirements of concrete specs. If ``expand_hashes``
@ -1640,7 +1675,7 @@ def _spec_clauses(
""" """
clauses = [] clauses = []
f = _Body if body else _Head f: Union[Type[_Head], Type[_Body]] = _Body if body else _Head
if spec.name: if spec.name:
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name)) clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
@ -1729,8 +1764,9 @@ def _spec_clauses(
# dependencies # dependencies
if spec.concrete: if spec.concrete:
# older specs do not have package hashes, so we have to do this carefully # older specs do not have package hashes, so we have to do this carefully
if getattr(spec, "_package_hash", None): package_hash = getattr(spec, "_package_hash", None)
clauses.append(fn.attr("package_hash", spec.name, spec._package_hash)) if package_hash:
clauses.append(fn.attr("package_hash", spec.name, package_hash))
clauses.append(fn.attr("hash", spec.name, spec.dag_hash())) clauses.append(fn.attr("hash", spec.name, spec.dag_hash()))
edges = spec.edges_from_dependents() edges = spec.edges_from_dependents()
@ -1789,7 +1825,7 @@ def _spec_clauses(
return clauses return clauses
def define_package_versions_and_validate_preferences( def define_package_versions_and_validate_preferences(
self, possible_pkgs, *, require_checksum: bool, allow_deprecated: bool self, possible_pkgs: Set[str], *, require_checksum: bool, allow_deprecated: bool
): ):
"""Declare any versions in specs not declared in packages.""" """Declare any versions in specs not declared in packages."""
packages_yaml = spack.config.get("packages") packages_yaml = spack.config.get("packages")
@ -1822,7 +1858,7 @@ def define_package_versions_and_validate_preferences(
if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]: if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]:
continue continue
version_defs = [] version_defs: List[GitOrStandardVersion] = []
for vstr in packages_yaml[pkg_name]["version"]: for vstr in packages_yaml[pkg_name]["version"]:
v = vn.ver(vstr) v = vn.ver(vstr)
@ -2033,13 +2069,6 @@ def target_defaults(self, specs):
def virtual_providers(self): def virtual_providers(self):
self.gen.h2("Virtual providers") self.gen.h2("Virtual providers")
msg = (
"Internal Error: possible_virtuals is not populated. Please report to the spack"
" maintainers"
)
assert self.possible_virtuals is not None, msg
# what provides what
for vspec in sorted(self.possible_virtuals): for vspec in sorted(self.possible_virtuals):
self.gen.fact(fn.virtual(vspec)) self.gen.fact(fn.virtual(vspec))
self.gen.newline() self.gen.newline()
@ -2236,7 +2265,7 @@ def define_concrete_input_specs(self, specs, possible):
def setup( def setup(
self, self,
specs: Sequence[spack.spec.Spec], specs: List[spack.spec.Spec],
*, *,
reuse: Optional[List[spack.spec.Spec]] = None, reuse: Optional[List[spack.spec.Spec]] = None,
allow_deprecated: bool = False, allow_deprecated: bool = False,