refactor: add type annotations and refactor solver conditions (#42081)

Refactoring `SpackSolverSetup` is a bit easier with type annotations, so I started
adding some. This adds annotations for the (many) instance variables on
`SpackSolverSetup` as well as a few other places.

This also refactors `condition()` to reduce redundancy and to allow
`_get_condition_id()` to be called independently of the larger condition
function.


Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
This commit is contained in:
Todd Gamblin 2024-02-26 14:26:01 -08:00 committed by GitHub
parent c7df258ca6
commit 48088ee24a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 91 deletions

View File

@ -566,6 +566,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
provided_together: Dict["spack.spec.Spec", List[Set[str]]]
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
variants: Dict[str, Tuple["spack.variant.Variant", "spack.spec.Spec"]]
#: By default, packages are not virtual
#: Virtual packages override this attribute

View File

@ -15,7 +15,7 @@
import types
import typing
import warnings
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Type, Union
import archspec.cpu
@ -258,7 +258,7 @@ def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunc
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
def _create_counter(specs, tests):
def _create_counter(specs: List[spack.spec.Spec], tests: bool):
strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none")
if strategy == "full":
return FullDuplicatesCounter(specs, tests=tests)
@ -897,35 +897,41 @@ def __iter__(self):
return iter(self.data)
# types for condition caching in solver setup
ConditionSpecKey = Tuple[str, Optional[TransformFunction]]
ConditionIdFunctionPair = Tuple[int, List[AspFunction]]
ConditionSpecCache = Dict[str, Dict[ConditionSpecKey, ConditionIdFunctionPair]]
class SpackSolverSetup:
"""Class to set up and run a Spack concretization solve."""
def __init__(self, tests=False):
self.gen = None # set by setup()
def __init__(self, tests: bool = False):
# these are all initialized in setup()
self.gen: "ProblemInstanceBuilder" = ProblemInstanceBuilder()
self.possible_virtuals: Set[str] = set()
self.assumptions = []
self.declared_versions = collections.defaultdict(list)
self.possible_versions = collections.defaultdict(set)
self.deprecated_versions = collections.defaultdict(set)
self.assumptions: List[Tuple["clingo.Symbol", bool]] = [] # type: ignore[name-defined]
self.declared_versions: Dict[str, List[DeclaredVersion]] = collections.defaultdict(list)
self.possible_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(set)
self.deprecated_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(
set
)
self.possible_virtuals = None
self.possible_compilers = []
self.possible_oses = set()
self.variant_values_from_specs = set()
self.version_constraints = set()
self.target_constraints = set()
self.default_targets = []
self.compiler_version_constraints = set()
self.post_facts = []
self.possible_compilers: List = []
self.possible_oses: Set = set()
self.variant_values_from_specs: Set = set()
self.version_constraints: Set = set()
self.target_constraints: Set = set()
self.default_targets: List = []
self.compiler_version_constraints: Set = set()
self.post_facts: List = []
# (ID, CompilerSpec) -> dictionary of attributes
self.compiler_info = collections.defaultdict(dict)
self.reusable_and_possible: ConcreteSpecsByHash = ConcreteSpecsByHash()
self.reusable_and_possible = ConcreteSpecsByHash()
self._id_counter = itertools.count()
self._trigger_cache = collections.defaultdict(dict)
self._effect_cache = collections.defaultdict(dict)
self._id_counter: Iterator[int] = itertools.count()
self._trigger_cache: ConditionSpecCache = collections.defaultdict(dict)
self._effect_cache: ConditionSpecCache = collections.defaultdict(dict)
# Caches to optimize the setup phase of the solver
self.target_specs_cache = None
@ -937,8 +943,8 @@ def __init__(self, tests=False):
self.concretize_everything = True
# Set during the call to setup
self.pkgs = None
self.explicitly_required_namespaces = {}
self.pkgs: Set[str] = set()
self.explicitly_required_namespaces: Dict[str, str] = {}
def pkg_version_rules(self, pkg):
"""Output declared versions of a package.
@ -1222,6 +1228,38 @@ def variant_rules(self, pkg):
self.gen.newline()
def _get_condition_id(
self,
named_cond: spack.spec.Spec,
cache: ConditionSpecCache,
body: bool,
transform: Optional[TransformFunction] = None,
) -> int:
"""Get the id for one half of a condition (either a trigger or an imposed constraint).
Construct a key from the condition spec and any associated transformation, and
cache the ASP functions that they imply. The saved functions will be output
later in ``trigger_rules()`` and ``effect_rules()``.
Returns:
The id of the cached trigger or effect.
"""
pkg_cache = cache[named_cond.name]
named_cond_key = (str(named_cond), transform)
result = pkg_cache.get(named_cond_key)
if result:
return result[0]
cond_id = next(self._id_counter)
requirements = self.spec_clauses(named_cond, body=body)
if transform:
requirements = transform(named_cond, requirements)
pkg_cache[named_cond_key] = (cond_id, requirements)
return cond_id
def condition(
self,
required_spec: spack.spec.Spec,
@ -1247,7 +1285,8 @@ def condition(
"""
named_cond = required_spec.copy()
named_cond.name = named_cond.name or name
assert named_cond.name, "must provide name for anonymous conditions!"
if not named_cond.name:
raise ValueError(f"Must provide a name for anonymous condition: '{named_cond}'")
# Check if we can emit the requirements before updating the condition ID counter.
# In this way, if a condition can't be emitted but the exception is handled in the caller,
@ -1257,35 +1296,19 @@ def condition(
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg))
cache = self._trigger_cache[named_cond.name]
named_cond_key = (str(named_cond), transform_required)
if named_cond_key not in cache:
trigger_id = next(self._id_counter)
requirements = self.spec_clauses(named_cond, body=True, required_from=name)
if transform_required:
requirements = transform_required(named_cond, requirements)
cache[named_cond_key] = (trigger_id, requirements)
trigger_id, requirements = cache[named_cond_key]
trigger_id = self._get_condition_id(
named_cond, cache=self._trigger_cache, body=True, transform=transform_required
)
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
if not imposed_spec:
return condition_id
cache = self._effect_cache[named_cond.name]
imposed_spec_key = (str(imposed_spec), transform_imposed)
if imposed_spec_key not in cache:
effect_id = next(self._id_counter)
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
if transform_imposed:
requirements = transform_imposed(imposed_spec, requirements)
cache[imposed_spec_key] = (effect_id, requirements)
effect_id, requirements = cache[imposed_spec_key]
effect_id = self._get_condition_id(
imposed_spec, cache=self._effect_cache, body=False, transform=transform_imposed
)
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
return condition_id
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
@ -1387,23 +1410,13 @@ def virtual_preferences(self, pkg_name, func):
def provider_defaults(self):
self.gen.h2("Default virtual providers")
msg = (
"Internal Error: possible_virtuals is not populated. Please report to the spack"
" maintainers"
)
assert self.possible_virtuals is not None, msg
self.virtual_preferences(
"all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i))
)
def provider_requirements(self):
self.gen.h2("Requirements on virtual providers")
msg = (
"Internal Error: possible_virtuals is not populated. Please report to the spack"
" maintainers"
)
parser = RequirementParser(spack.config.CONFIG)
assert self.possible_virtuals is not None, msg
for virtual_str in sorted(self.possible_virtuals):
rules = parser.rules_from_virtual(virtual_str)
if rules:
@ -1602,35 +1615,57 @@ def flag_defaults(self):
fn.compiler_version_flag(compiler.name, compiler.version, name, flag)
)
def spec_clauses(self, *args, **kwargs):
"""Wrap a call to `_spec_clauses()` into a try/except block that
raises a comprehensible error message in case of failure.
def spec_clauses(
self,
spec: spack.spec.Spec,
*,
body: bool = False,
transitive: bool = True,
expand_hashes: bool = False,
concrete_build_deps=False,
required_from: Optional[str] = None,
) -> List[AspFunction]:
"""Wrap a call to `_spec_clauses()` into a try/except block with better error handling.
Arguments are as for ``_spec_clauses()`` except ``required_from``.
Arguments:
required_from: name of package that caused this call.
"""
requestor = kwargs.pop("required_from", None)
try:
clauses = self._spec_clauses(*args, **kwargs)
clauses = self._spec_clauses(
spec,
body=body,
transitive=transitive,
expand_hashes=expand_hashes,
concrete_build_deps=concrete_build_deps,
)
except RuntimeError as exc:
msg = str(exc)
if requestor:
msg += ' [required from package "{0}"]'.format(requestor)
if required_from:
msg += f" [required from package '{required_from}']"
raise RuntimeError(msg)
return clauses
def _spec_clauses(
self, spec, body=False, transitive=True, expand_hashes=False, concrete_build_deps=False
):
self,
spec: spack.spec.Spec,
*,
body: bool = False,
transitive: bool = True,
expand_hashes: bool = False,
concrete_build_deps: bool = False,
) -> List[AspFunction]:
"""Return a list of clauses for a spec mandates are true.
Arguments:
spec (spack.spec.Spec): the spec to analyze
body (bool): if True, generate clauses to be used in rule bodies
(final values) instead of rule heads (setters).
transitive (bool): if False, don't generate clauses from
dependencies (default True)
expand_hashes (bool): if True, descend into hashes of concrete specs
(default False)
concrete_build_deps (bool): if False, do not include pure build deps
of concrete specs (as they have no effect on runtime constraints)
spec: the spec to analyze
body: if True, generate clauses to be used in rule bodies (final values) instead
of rule heads (setters).
transitive: if False, don't generate clauses from dependencies (default True)
expand_hashes: if True, descend into hashes of concrete specs (default False)
concrete_build_deps: if False, do not include pure build deps of concrete specs
(as they have no effect on runtime constraints)
Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates
hashes for the dependency requirements of concrete specs. If ``expand_hashes``
@ -1640,7 +1675,7 @@ def _spec_clauses(
"""
clauses = []
f = _Body if body else _Head
f: Union[Type[_Head], Type[_Body]] = _Body if body else _Head
if spec.name:
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
@ -1729,8 +1764,9 @@ def _spec_clauses(
# dependencies
if spec.concrete:
# older specs do not have package hashes, so we have to do this carefully
if getattr(spec, "_package_hash", None):
clauses.append(fn.attr("package_hash", spec.name, spec._package_hash))
package_hash = getattr(spec, "_package_hash", None)
if package_hash:
clauses.append(fn.attr("package_hash", spec.name, package_hash))
clauses.append(fn.attr("hash", spec.name, spec.dag_hash()))
edges = spec.edges_from_dependents()
@ -1789,7 +1825,7 @@ def _spec_clauses(
return clauses
def define_package_versions_and_validate_preferences(
self, possible_pkgs, *, require_checksum: bool, allow_deprecated: bool
self, possible_pkgs: Set[str], *, require_checksum: bool, allow_deprecated: bool
):
"""Declare any versions in specs not declared in packages."""
packages_yaml = spack.config.get("packages")
@ -1822,7 +1858,7 @@ def define_package_versions_and_validate_preferences(
if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]:
continue
version_defs = []
version_defs: List[GitOrStandardVersion] = []
for vstr in packages_yaml[pkg_name]["version"]:
v = vn.ver(vstr)
@ -2033,13 +2069,6 @@ def target_defaults(self, specs):
def virtual_providers(self):
self.gen.h2("Virtual providers")
msg = (
"Internal Error: possible_virtuals is not populated. Please report to the spack"
" maintainers"
)
assert self.possible_virtuals is not None, msg
# what provides what
for vspec in sorted(self.possible_virtuals):
self.gen.fact(fn.virtual(vspec))
self.gen.newline()
@ -2236,7 +2265,7 @@ def define_concrete_input_specs(self, specs, possible):
def setup(
self,
specs: Sequence[spack.spec.Spec],
specs: List[spack.spec.Spec],
*,
reuse: Optional[List[spack.spec.Spec]] = None,
allow_deprecated: bool = False,