refactor: add type annotations and refactor solver conditions (#42081)
Refactoring `SpackSolverSetup` is a bit easier with type annotations, so I started adding some. This adds annotations for the (many) instance variables on `SpackSolverSetup` as well as a few other places. This also refactors `condition()` to reduce redundancy and to allow `_get_condition_id()` to be called independently of the larger condition function. Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
This commit is contained in:
parent
c7df258ca6
commit
48088ee24a
@ -566,6 +566,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
|
||||
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
|
||||
provided_together: Dict["spack.spec.Spec", List[Set[str]]]
|
||||
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
|
||||
variants: Dict[str, Tuple["spack.variant.Variant", "spack.spec.Spec"]]
|
||||
|
||||
#: By default, packages are not virtual
|
||||
#: Virtual packages override this attribute
|
||||
|
@ -15,7 +15,7 @@
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import archspec.cpu
|
||||
|
||||
@ -258,7 +258,7 @@ def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunc
|
||||
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
|
||||
|
||||
|
||||
def _create_counter(specs, tests):
|
||||
def _create_counter(specs: List[spack.spec.Spec], tests: bool):
|
||||
strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none")
|
||||
if strategy == "full":
|
||||
return FullDuplicatesCounter(specs, tests=tests)
|
||||
@ -897,35 +897,41 @@ def __iter__(self):
|
||||
return iter(self.data)
|
||||
|
||||
|
||||
# types for condition caching in solver setup
|
||||
ConditionSpecKey = Tuple[str, Optional[TransformFunction]]
|
||||
ConditionIdFunctionPair = Tuple[int, List[AspFunction]]
|
||||
ConditionSpecCache = Dict[str, Dict[ConditionSpecKey, ConditionIdFunctionPair]]
|
||||
|
||||
|
||||
class SpackSolverSetup:
|
||||
"""Class to set up and run a Spack concretization solve."""
|
||||
|
||||
def __init__(self, tests=False):
|
||||
self.gen = None # set by setup()
|
||||
def __init__(self, tests: bool = False):
|
||||
# these are all initialized in setup()
|
||||
self.gen: "ProblemInstanceBuilder" = ProblemInstanceBuilder()
|
||||
self.possible_virtuals: Set[str] = set()
|
||||
|
||||
self.assumptions = []
|
||||
self.declared_versions = collections.defaultdict(list)
|
||||
self.possible_versions = collections.defaultdict(set)
|
||||
self.deprecated_versions = collections.defaultdict(set)
|
||||
self.assumptions: List[Tuple["clingo.Symbol", bool]] = [] # type: ignore[name-defined]
|
||||
self.declared_versions: Dict[str, List[DeclaredVersion]] = collections.defaultdict(list)
|
||||
self.possible_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(set)
|
||||
self.deprecated_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(
|
||||
set
|
||||
)
|
||||
|
||||
self.possible_virtuals = None
|
||||
self.possible_compilers = []
|
||||
self.possible_oses = set()
|
||||
self.variant_values_from_specs = set()
|
||||
self.version_constraints = set()
|
||||
self.target_constraints = set()
|
||||
self.default_targets = []
|
||||
self.compiler_version_constraints = set()
|
||||
self.post_facts = []
|
||||
self.possible_compilers: List = []
|
||||
self.possible_oses: Set = set()
|
||||
self.variant_values_from_specs: Set = set()
|
||||
self.version_constraints: Set = set()
|
||||
self.target_constraints: Set = set()
|
||||
self.default_targets: List = []
|
||||
self.compiler_version_constraints: Set = set()
|
||||
self.post_facts: List = []
|
||||
|
||||
# (ID, CompilerSpec) -> dictionary of attributes
|
||||
self.compiler_info = collections.defaultdict(dict)
|
||||
self.reusable_and_possible: ConcreteSpecsByHash = ConcreteSpecsByHash()
|
||||
|
||||
self.reusable_and_possible = ConcreteSpecsByHash()
|
||||
|
||||
self._id_counter = itertools.count()
|
||||
self._trigger_cache = collections.defaultdict(dict)
|
||||
self._effect_cache = collections.defaultdict(dict)
|
||||
self._id_counter: Iterator[int] = itertools.count()
|
||||
self._trigger_cache: ConditionSpecCache = collections.defaultdict(dict)
|
||||
self._effect_cache: ConditionSpecCache = collections.defaultdict(dict)
|
||||
|
||||
# Caches to optimize the setup phase of the solver
|
||||
self.target_specs_cache = None
|
||||
@ -937,8 +943,8 @@ def __init__(self, tests=False):
|
||||
self.concretize_everything = True
|
||||
|
||||
# Set during the call to setup
|
||||
self.pkgs = None
|
||||
self.explicitly_required_namespaces = {}
|
||||
self.pkgs: Set[str] = set()
|
||||
self.explicitly_required_namespaces: Dict[str, str] = {}
|
||||
|
||||
def pkg_version_rules(self, pkg):
|
||||
"""Output declared versions of a package.
|
||||
@ -1222,6 +1228,38 @@ def variant_rules(self, pkg):
|
||||
|
||||
self.gen.newline()
|
||||
|
||||
def _get_condition_id(
|
||||
self,
|
||||
named_cond: spack.spec.Spec,
|
||||
cache: ConditionSpecCache,
|
||||
body: bool,
|
||||
transform: Optional[TransformFunction] = None,
|
||||
) -> int:
|
||||
"""Get the id for one half of a condition (either a trigger or an imposed constraint).
|
||||
|
||||
Construct a key from the condition spec and any associated transformation, and
|
||||
cache the ASP functions that they imply. The saved functions will be output
|
||||
later in ``trigger_rules()`` and ``effect_rules()``.
|
||||
|
||||
Returns:
|
||||
The id of the cached trigger or effect.
|
||||
|
||||
"""
|
||||
pkg_cache = cache[named_cond.name]
|
||||
|
||||
named_cond_key = (str(named_cond), transform)
|
||||
result = pkg_cache.get(named_cond_key)
|
||||
if result:
|
||||
return result[0]
|
||||
|
||||
cond_id = next(self._id_counter)
|
||||
requirements = self.spec_clauses(named_cond, body=body)
|
||||
if transform:
|
||||
requirements = transform(named_cond, requirements)
|
||||
pkg_cache[named_cond_key] = (cond_id, requirements)
|
||||
|
||||
return cond_id
|
||||
|
||||
def condition(
|
||||
self,
|
||||
required_spec: spack.spec.Spec,
|
||||
@ -1247,7 +1285,8 @@ def condition(
|
||||
"""
|
||||
named_cond = required_spec.copy()
|
||||
named_cond.name = named_cond.name or name
|
||||
assert named_cond.name, "must provide name for anonymous conditions!"
|
||||
if not named_cond.name:
|
||||
raise ValueError(f"Must provide a name for anonymous condition: '{named_cond}'")
|
||||
|
||||
# Check if we can emit the requirements before updating the condition ID counter.
|
||||
# In this way, if a condition can't be emitted but the exception is handled in the caller,
|
||||
@ -1257,35 +1296,19 @@ def condition(
|
||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
|
||||
self.gen.fact(fn.condition_reason(condition_id, msg))
|
||||
|
||||
cache = self._trigger_cache[named_cond.name]
|
||||
|
||||
named_cond_key = (str(named_cond), transform_required)
|
||||
if named_cond_key not in cache:
|
||||
trigger_id = next(self._id_counter)
|
||||
requirements = self.spec_clauses(named_cond, body=True, required_from=name)
|
||||
|
||||
if transform_required:
|
||||
requirements = transform_required(named_cond, requirements)
|
||||
|
||||
cache[named_cond_key] = (trigger_id, requirements)
|
||||
trigger_id, requirements = cache[named_cond_key]
|
||||
trigger_id = self._get_condition_id(
|
||||
named_cond, cache=self._trigger_cache, body=True, transform=transform_required
|
||||
)
|
||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
|
||||
|
||||
if not imposed_spec:
|
||||
return condition_id
|
||||
|
||||
cache = self._effect_cache[named_cond.name]
|
||||
imposed_spec_key = (str(imposed_spec), transform_imposed)
|
||||
if imposed_spec_key not in cache:
|
||||
effect_id = next(self._id_counter)
|
||||
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
|
||||
|
||||
if transform_imposed:
|
||||
requirements = transform_imposed(imposed_spec, requirements)
|
||||
|
||||
cache[imposed_spec_key] = (effect_id, requirements)
|
||||
effect_id, requirements = cache[imposed_spec_key]
|
||||
effect_id = self._get_condition_id(
|
||||
imposed_spec, cache=self._effect_cache, body=False, transform=transform_imposed
|
||||
)
|
||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
|
||||
|
||||
return condition_id
|
||||
|
||||
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
|
||||
@ -1387,23 +1410,13 @@ def virtual_preferences(self, pkg_name, func):
|
||||
|
||||
def provider_defaults(self):
|
||||
self.gen.h2("Default virtual providers")
|
||||
msg = (
|
||||
"Internal Error: possible_virtuals is not populated. Please report to the spack"
|
||||
" maintainers"
|
||||
)
|
||||
assert self.possible_virtuals is not None, msg
|
||||
self.virtual_preferences(
|
||||
"all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i))
|
||||
)
|
||||
|
||||
def provider_requirements(self):
|
||||
self.gen.h2("Requirements on virtual providers")
|
||||
msg = (
|
||||
"Internal Error: possible_virtuals is not populated. Please report to the spack"
|
||||
" maintainers"
|
||||
)
|
||||
parser = RequirementParser(spack.config.CONFIG)
|
||||
assert self.possible_virtuals is not None, msg
|
||||
for virtual_str in sorted(self.possible_virtuals):
|
||||
rules = parser.rules_from_virtual(virtual_str)
|
||||
if rules:
|
||||
@ -1602,35 +1615,57 @@ def flag_defaults(self):
|
||||
fn.compiler_version_flag(compiler.name, compiler.version, name, flag)
|
||||
)
|
||||
|
||||
def spec_clauses(self, *args, **kwargs):
|
||||
"""Wrap a call to `_spec_clauses()` into a try/except block that
|
||||
raises a comprehensible error message in case of failure.
|
||||
def spec_clauses(
|
||||
self,
|
||||
spec: spack.spec.Spec,
|
||||
*,
|
||||
body: bool = False,
|
||||
transitive: bool = True,
|
||||
expand_hashes: bool = False,
|
||||
concrete_build_deps=False,
|
||||
required_from: Optional[str] = None,
|
||||
) -> List[AspFunction]:
|
||||
"""Wrap a call to `_spec_clauses()` into a try/except block with better error handling.
|
||||
|
||||
Arguments are as for ``_spec_clauses()`` except ``required_from``.
|
||||
|
||||
Arguments:
|
||||
required_from: name of package that caused this call.
|
||||
"""
|
||||
requestor = kwargs.pop("required_from", None)
|
||||
try:
|
||||
clauses = self._spec_clauses(*args, **kwargs)
|
||||
clauses = self._spec_clauses(
|
||||
spec,
|
||||
body=body,
|
||||
transitive=transitive,
|
||||
expand_hashes=expand_hashes,
|
||||
concrete_build_deps=concrete_build_deps,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
msg = str(exc)
|
||||
if requestor:
|
||||
msg += ' [required from package "{0}"]'.format(requestor)
|
||||
if required_from:
|
||||
msg += f" [required from package '{required_from}']"
|
||||
raise RuntimeError(msg)
|
||||
return clauses
|
||||
|
||||
def _spec_clauses(
|
||||
self, spec, body=False, transitive=True, expand_hashes=False, concrete_build_deps=False
|
||||
):
|
||||
self,
|
||||
spec: spack.spec.Spec,
|
||||
*,
|
||||
body: bool = False,
|
||||
transitive: bool = True,
|
||||
expand_hashes: bool = False,
|
||||
concrete_build_deps: bool = False,
|
||||
) -> List[AspFunction]:
|
||||
"""Return a list of clauses for a spec mandates are true.
|
||||
|
||||
Arguments:
|
||||
spec (spack.spec.Spec): the spec to analyze
|
||||
body (bool): if True, generate clauses to be used in rule bodies
|
||||
(final values) instead of rule heads (setters).
|
||||
transitive (bool): if False, don't generate clauses from
|
||||
dependencies (default True)
|
||||
expand_hashes (bool): if True, descend into hashes of concrete specs
|
||||
(default False)
|
||||
concrete_build_deps (bool): if False, do not include pure build deps
|
||||
of concrete specs (as they have no effect on runtime constraints)
|
||||
spec: the spec to analyze
|
||||
body: if True, generate clauses to be used in rule bodies (final values) instead
|
||||
of rule heads (setters).
|
||||
transitive: if False, don't generate clauses from dependencies (default True)
|
||||
expand_hashes: if True, descend into hashes of concrete specs (default False)
|
||||
concrete_build_deps: if False, do not include pure build deps of concrete specs
|
||||
(as they have no effect on runtime constraints)
|
||||
|
||||
Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates
|
||||
hashes for the dependency requirements of concrete specs. If ``expand_hashes``
|
||||
@ -1640,7 +1675,7 @@ def _spec_clauses(
|
||||
"""
|
||||
clauses = []
|
||||
|
||||
f = _Body if body else _Head
|
||||
f: Union[Type[_Head], Type[_Body]] = _Body if body else _Head
|
||||
|
||||
if spec.name:
|
||||
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
|
||||
@ -1729,8 +1764,9 @@ def _spec_clauses(
|
||||
# dependencies
|
||||
if spec.concrete:
|
||||
# older specs do not have package hashes, so we have to do this carefully
|
||||
if getattr(spec, "_package_hash", None):
|
||||
clauses.append(fn.attr("package_hash", spec.name, spec._package_hash))
|
||||
package_hash = getattr(spec, "_package_hash", None)
|
||||
if package_hash:
|
||||
clauses.append(fn.attr("package_hash", spec.name, package_hash))
|
||||
clauses.append(fn.attr("hash", spec.name, spec.dag_hash()))
|
||||
|
||||
edges = spec.edges_from_dependents()
|
||||
@ -1789,7 +1825,7 @@ def _spec_clauses(
|
||||
return clauses
|
||||
|
||||
def define_package_versions_and_validate_preferences(
|
||||
self, possible_pkgs, *, require_checksum: bool, allow_deprecated: bool
|
||||
self, possible_pkgs: Set[str], *, require_checksum: bool, allow_deprecated: bool
|
||||
):
|
||||
"""Declare any versions in specs not declared in packages."""
|
||||
packages_yaml = spack.config.get("packages")
|
||||
@ -1822,7 +1858,7 @@ def define_package_versions_and_validate_preferences(
|
||||
if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]:
|
||||
continue
|
||||
|
||||
version_defs = []
|
||||
version_defs: List[GitOrStandardVersion] = []
|
||||
|
||||
for vstr in packages_yaml[pkg_name]["version"]:
|
||||
v = vn.ver(vstr)
|
||||
@ -2033,13 +2069,6 @@ def target_defaults(self, specs):
|
||||
|
||||
def virtual_providers(self):
|
||||
self.gen.h2("Virtual providers")
|
||||
msg = (
|
||||
"Internal Error: possible_virtuals is not populated. Please report to the spack"
|
||||
" maintainers"
|
||||
)
|
||||
assert self.possible_virtuals is not None, msg
|
||||
|
||||
# what provides what
|
||||
for vspec in sorted(self.possible_virtuals):
|
||||
self.gen.fact(fn.virtual(vspec))
|
||||
self.gen.newline()
|
||||
@ -2236,7 +2265,7 @@ def define_concrete_input_specs(self, specs, possible):
|
||||
|
||||
def setup(
|
||||
self,
|
||||
specs: Sequence[spack.spec.Spec],
|
||||
specs: List[spack.spec.Spec],
|
||||
*,
|
||||
reuse: Optional[List[spack.spec.Spec]] = None,
|
||||
allow_deprecated: bool = False,
|
||||
|
Loading…
Reference in New Issue
Block a user