refactor: add type annotations and refactor solver conditions (#42081)
Refactoring `SpackSolverSetup` is a bit easier with type annotations, so I started adding some. This adds annotations for the (many) instance variables on `SpackSolverSetup` as well as a few other places. This also refactors `condition()` to reduce redundancy and to allow `_get_condition_id()` to be called independently of the larger condition function. Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
This commit is contained in:
parent
c7df258ca6
commit
48088ee24a
@ -566,6 +566,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
|
|||||||
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
|
provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
|
||||||
provided_together: Dict["spack.spec.Spec", List[Set[str]]]
|
provided_together: Dict["spack.spec.Spec", List[Set[str]]]
|
||||||
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
|
patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
|
||||||
|
variants: Dict[str, Tuple["spack.variant.Variant", "spack.spec.Spec"]]
|
||||||
|
|
||||||
#: By default, packages are not virtual
|
#: By default, packages are not virtual
|
||||||
#: Virtual packages override this attribute
|
#: Virtual packages override this attribute
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
import types
|
import types
|
||||||
import typing
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
|
from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
import archspec.cpu
|
import archspec.cpu
|
||||||
|
|
||||||
@ -258,7 +258,7 @@ def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunc
|
|||||||
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
|
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
|
||||||
|
|
||||||
|
|
||||||
def _create_counter(specs, tests):
|
def _create_counter(specs: List[spack.spec.Spec], tests: bool):
|
||||||
strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none")
|
strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none")
|
||||||
if strategy == "full":
|
if strategy == "full":
|
||||||
return FullDuplicatesCounter(specs, tests=tests)
|
return FullDuplicatesCounter(specs, tests=tests)
|
||||||
@ -897,35 +897,41 @@ def __iter__(self):
|
|||||||
return iter(self.data)
|
return iter(self.data)
|
||||||
|
|
||||||
|
|
||||||
|
# types for condition caching in solver setup
|
||||||
|
ConditionSpecKey = Tuple[str, Optional[TransformFunction]]
|
||||||
|
ConditionIdFunctionPair = Tuple[int, List[AspFunction]]
|
||||||
|
ConditionSpecCache = Dict[str, Dict[ConditionSpecKey, ConditionIdFunctionPair]]
|
||||||
|
|
||||||
|
|
||||||
class SpackSolverSetup:
|
class SpackSolverSetup:
|
||||||
"""Class to set up and run a Spack concretization solve."""
|
"""Class to set up and run a Spack concretization solve."""
|
||||||
|
|
||||||
def __init__(self, tests=False):
|
def __init__(self, tests: bool = False):
|
||||||
self.gen = None # set by setup()
|
# these are all initialized in setup()
|
||||||
|
self.gen: "ProblemInstanceBuilder" = ProblemInstanceBuilder()
|
||||||
|
self.possible_virtuals: Set[str] = set()
|
||||||
|
|
||||||
self.assumptions = []
|
self.assumptions: List[Tuple["clingo.Symbol", bool]] = [] # type: ignore[name-defined]
|
||||||
self.declared_versions = collections.defaultdict(list)
|
self.declared_versions: Dict[str, List[DeclaredVersion]] = collections.defaultdict(list)
|
||||||
self.possible_versions = collections.defaultdict(set)
|
self.possible_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(set)
|
||||||
self.deprecated_versions = collections.defaultdict(set)
|
self.deprecated_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(
|
||||||
|
set
|
||||||
|
)
|
||||||
|
|
||||||
self.possible_virtuals = None
|
self.possible_compilers: List = []
|
||||||
self.possible_compilers = []
|
self.possible_oses: Set = set()
|
||||||
self.possible_oses = set()
|
self.variant_values_from_specs: Set = set()
|
||||||
self.variant_values_from_specs = set()
|
self.version_constraints: Set = set()
|
||||||
self.version_constraints = set()
|
self.target_constraints: Set = set()
|
||||||
self.target_constraints = set()
|
self.default_targets: List = []
|
||||||
self.default_targets = []
|
self.compiler_version_constraints: Set = set()
|
||||||
self.compiler_version_constraints = set()
|
self.post_facts: List = []
|
||||||
self.post_facts = []
|
|
||||||
|
|
||||||
# (ID, CompilerSpec) -> dictionary of attributes
|
self.reusable_and_possible: ConcreteSpecsByHash = ConcreteSpecsByHash()
|
||||||
self.compiler_info = collections.defaultdict(dict)
|
|
||||||
|
|
||||||
self.reusable_and_possible = ConcreteSpecsByHash()
|
self._id_counter: Iterator[int] = itertools.count()
|
||||||
|
self._trigger_cache: ConditionSpecCache = collections.defaultdict(dict)
|
||||||
self._id_counter = itertools.count()
|
self._effect_cache: ConditionSpecCache = collections.defaultdict(dict)
|
||||||
self._trigger_cache = collections.defaultdict(dict)
|
|
||||||
self._effect_cache = collections.defaultdict(dict)
|
|
||||||
|
|
||||||
# Caches to optimize the setup phase of the solver
|
# Caches to optimize the setup phase of the solver
|
||||||
self.target_specs_cache = None
|
self.target_specs_cache = None
|
||||||
@ -937,8 +943,8 @@ def __init__(self, tests=False):
|
|||||||
self.concretize_everything = True
|
self.concretize_everything = True
|
||||||
|
|
||||||
# Set during the call to setup
|
# Set during the call to setup
|
||||||
self.pkgs = None
|
self.pkgs: Set[str] = set()
|
||||||
self.explicitly_required_namespaces = {}
|
self.explicitly_required_namespaces: Dict[str, str] = {}
|
||||||
|
|
||||||
def pkg_version_rules(self, pkg):
|
def pkg_version_rules(self, pkg):
|
||||||
"""Output declared versions of a package.
|
"""Output declared versions of a package.
|
||||||
@ -1222,6 +1228,38 @@ def variant_rules(self, pkg):
|
|||||||
|
|
||||||
self.gen.newline()
|
self.gen.newline()
|
||||||
|
|
||||||
|
def _get_condition_id(
|
||||||
|
self,
|
||||||
|
named_cond: spack.spec.Spec,
|
||||||
|
cache: ConditionSpecCache,
|
||||||
|
body: bool,
|
||||||
|
transform: Optional[TransformFunction] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Get the id for one half of a condition (either a trigger or an imposed constraint).
|
||||||
|
|
||||||
|
Construct a key from the condition spec and any associated transformation, and
|
||||||
|
cache the ASP functions that they imply. The saved functions will be output
|
||||||
|
later in ``trigger_rules()`` and ``effect_rules()``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The id of the cached trigger or effect.
|
||||||
|
|
||||||
|
"""
|
||||||
|
pkg_cache = cache[named_cond.name]
|
||||||
|
|
||||||
|
named_cond_key = (str(named_cond), transform)
|
||||||
|
result = pkg_cache.get(named_cond_key)
|
||||||
|
if result:
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
cond_id = next(self._id_counter)
|
||||||
|
requirements = self.spec_clauses(named_cond, body=body)
|
||||||
|
if transform:
|
||||||
|
requirements = transform(named_cond, requirements)
|
||||||
|
pkg_cache[named_cond_key] = (cond_id, requirements)
|
||||||
|
|
||||||
|
return cond_id
|
||||||
|
|
||||||
def condition(
|
def condition(
|
||||||
self,
|
self,
|
||||||
required_spec: spack.spec.Spec,
|
required_spec: spack.spec.Spec,
|
||||||
@ -1247,7 +1285,8 @@ def condition(
|
|||||||
"""
|
"""
|
||||||
named_cond = required_spec.copy()
|
named_cond = required_spec.copy()
|
||||||
named_cond.name = named_cond.name or name
|
named_cond.name = named_cond.name or name
|
||||||
assert named_cond.name, "must provide name for anonymous conditions!"
|
if not named_cond.name:
|
||||||
|
raise ValueError(f"Must provide a name for anonymous condition: '{named_cond}'")
|
||||||
|
|
||||||
# Check if we can emit the requirements before updating the condition ID counter.
|
# Check if we can emit the requirements before updating the condition ID counter.
|
||||||
# In this way, if a condition can't be emitted but the exception is handled in the caller,
|
# In this way, if a condition can't be emitted but the exception is handled in the caller,
|
||||||
@ -1257,35 +1296,19 @@ def condition(
|
|||||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
|
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
|
||||||
self.gen.fact(fn.condition_reason(condition_id, msg))
|
self.gen.fact(fn.condition_reason(condition_id, msg))
|
||||||
|
|
||||||
cache = self._trigger_cache[named_cond.name]
|
trigger_id = self._get_condition_id(
|
||||||
|
named_cond, cache=self._trigger_cache, body=True, transform=transform_required
|
||||||
named_cond_key = (str(named_cond), transform_required)
|
)
|
||||||
if named_cond_key not in cache:
|
|
||||||
trigger_id = next(self._id_counter)
|
|
||||||
requirements = self.spec_clauses(named_cond, body=True, required_from=name)
|
|
||||||
|
|
||||||
if transform_required:
|
|
||||||
requirements = transform_required(named_cond, requirements)
|
|
||||||
|
|
||||||
cache[named_cond_key] = (trigger_id, requirements)
|
|
||||||
trigger_id, requirements = cache[named_cond_key]
|
|
||||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
|
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
|
||||||
|
|
||||||
if not imposed_spec:
|
if not imposed_spec:
|
||||||
return condition_id
|
return condition_id
|
||||||
|
|
||||||
cache = self._effect_cache[named_cond.name]
|
effect_id = self._get_condition_id(
|
||||||
imposed_spec_key = (str(imposed_spec), transform_imposed)
|
imposed_spec, cache=self._effect_cache, body=False, transform=transform_imposed
|
||||||
if imposed_spec_key not in cache:
|
)
|
||||||
effect_id = next(self._id_counter)
|
|
||||||
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
|
|
||||||
|
|
||||||
if transform_imposed:
|
|
||||||
requirements = transform_imposed(imposed_spec, requirements)
|
|
||||||
|
|
||||||
cache[imposed_spec_key] = (effect_id, requirements)
|
|
||||||
effect_id, requirements = cache[imposed_spec_key]
|
|
||||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
|
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
|
||||||
|
|
||||||
return condition_id
|
return condition_id
|
||||||
|
|
||||||
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
|
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
|
||||||
@ -1387,23 +1410,13 @@ def virtual_preferences(self, pkg_name, func):
|
|||||||
|
|
||||||
def provider_defaults(self):
|
def provider_defaults(self):
|
||||||
self.gen.h2("Default virtual providers")
|
self.gen.h2("Default virtual providers")
|
||||||
msg = (
|
|
||||||
"Internal Error: possible_virtuals is not populated. Please report to the spack"
|
|
||||||
" maintainers"
|
|
||||||
)
|
|
||||||
assert self.possible_virtuals is not None, msg
|
|
||||||
self.virtual_preferences(
|
self.virtual_preferences(
|
||||||
"all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i))
|
"all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i))
|
||||||
)
|
)
|
||||||
|
|
||||||
def provider_requirements(self):
|
def provider_requirements(self):
|
||||||
self.gen.h2("Requirements on virtual providers")
|
self.gen.h2("Requirements on virtual providers")
|
||||||
msg = (
|
|
||||||
"Internal Error: possible_virtuals is not populated. Please report to the spack"
|
|
||||||
" maintainers"
|
|
||||||
)
|
|
||||||
parser = RequirementParser(spack.config.CONFIG)
|
parser = RequirementParser(spack.config.CONFIG)
|
||||||
assert self.possible_virtuals is not None, msg
|
|
||||||
for virtual_str in sorted(self.possible_virtuals):
|
for virtual_str in sorted(self.possible_virtuals):
|
||||||
rules = parser.rules_from_virtual(virtual_str)
|
rules = parser.rules_from_virtual(virtual_str)
|
||||||
if rules:
|
if rules:
|
||||||
@ -1602,35 +1615,57 @@ def flag_defaults(self):
|
|||||||
fn.compiler_version_flag(compiler.name, compiler.version, name, flag)
|
fn.compiler_version_flag(compiler.name, compiler.version, name, flag)
|
||||||
)
|
)
|
||||||
|
|
||||||
def spec_clauses(self, *args, **kwargs):
|
def spec_clauses(
|
||||||
"""Wrap a call to `_spec_clauses()` into a try/except block that
|
self,
|
||||||
raises a comprehensible error message in case of failure.
|
spec: spack.spec.Spec,
|
||||||
|
*,
|
||||||
|
body: bool = False,
|
||||||
|
transitive: bool = True,
|
||||||
|
expand_hashes: bool = False,
|
||||||
|
concrete_build_deps=False,
|
||||||
|
required_from: Optional[str] = None,
|
||||||
|
) -> List[AspFunction]:
|
||||||
|
"""Wrap a call to `_spec_clauses()` into a try/except block with better error handling.
|
||||||
|
|
||||||
|
Arguments are as for ``_spec_clauses()`` except ``required_from``.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
required_from: name of package that caused this call.
|
||||||
"""
|
"""
|
||||||
requestor = kwargs.pop("required_from", None)
|
|
||||||
try:
|
try:
|
||||||
clauses = self._spec_clauses(*args, **kwargs)
|
clauses = self._spec_clauses(
|
||||||
|
spec,
|
||||||
|
body=body,
|
||||||
|
transitive=transitive,
|
||||||
|
expand_hashes=expand_hashes,
|
||||||
|
concrete_build_deps=concrete_build_deps,
|
||||||
|
)
|
||||||
except RuntimeError as exc:
|
except RuntimeError as exc:
|
||||||
msg = str(exc)
|
msg = str(exc)
|
||||||
if requestor:
|
if required_from:
|
||||||
msg += ' [required from package "{0}"]'.format(requestor)
|
msg += f" [required from package '{required_from}']"
|
||||||
raise RuntimeError(msg)
|
raise RuntimeError(msg)
|
||||||
return clauses
|
return clauses
|
||||||
|
|
||||||
def _spec_clauses(
|
def _spec_clauses(
|
||||||
self, spec, body=False, transitive=True, expand_hashes=False, concrete_build_deps=False
|
self,
|
||||||
):
|
spec: spack.spec.Spec,
|
||||||
|
*,
|
||||||
|
body: bool = False,
|
||||||
|
transitive: bool = True,
|
||||||
|
expand_hashes: bool = False,
|
||||||
|
concrete_build_deps: bool = False,
|
||||||
|
) -> List[AspFunction]:
|
||||||
"""Return a list of clauses for a spec mandates are true.
|
"""Return a list of clauses for a spec mandates are true.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
spec (spack.spec.Spec): the spec to analyze
|
spec: the spec to analyze
|
||||||
body (bool): if True, generate clauses to be used in rule bodies
|
body: if True, generate clauses to be used in rule bodies (final values) instead
|
||||||
(final values) instead of rule heads (setters).
|
of rule heads (setters).
|
||||||
transitive (bool): if False, don't generate clauses from
|
transitive: if False, don't generate clauses from dependencies (default True)
|
||||||
dependencies (default True)
|
expand_hashes: if True, descend into hashes of concrete specs (default False)
|
||||||
expand_hashes (bool): if True, descend into hashes of concrete specs
|
concrete_build_deps: if False, do not include pure build deps of concrete specs
|
||||||
(default False)
|
(as they have no effect on runtime constraints)
|
||||||
concrete_build_deps (bool): if False, do not include pure build deps
|
|
||||||
of concrete specs (as they have no effect on runtime constraints)
|
|
||||||
|
|
||||||
Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates
|
Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates
|
||||||
hashes for the dependency requirements of concrete specs. If ``expand_hashes``
|
hashes for the dependency requirements of concrete specs. If ``expand_hashes``
|
||||||
@ -1640,7 +1675,7 @@ def _spec_clauses(
|
|||||||
"""
|
"""
|
||||||
clauses = []
|
clauses = []
|
||||||
|
|
||||||
f = _Body if body else _Head
|
f: Union[Type[_Head], Type[_Body]] = _Body if body else _Head
|
||||||
|
|
||||||
if spec.name:
|
if spec.name:
|
||||||
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
|
clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
|
||||||
@ -1729,8 +1764,9 @@ def _spec_clauses(
|
|||||||
# dependencies
|
# dependencies
|
||||||
if spec.concrete:
|
if spec.concrete:
|
||||||
# older specs do not have package hashes, so we have to do this carefully
|
# older specs do not have package hashes, so we have to do this carefully
|
||||||
if getattr(spec, "_package_hash", None):
|
package_hash = getattr(spec, "_package_hash", None)
|
||||||
clauses.append(fn.attr("package_hash", spec.name, spec._package_hash))
|
if package_hash:
|
||||||
|
clauses.append(fn.attr("package_hash", spec.name, package_hash))
|
||||||
clauses.append(fn.attr("hash", spec.name, spec.dag_hash()))
|
clauses.append(fn.attr("hash", spec.name, spec.dag_hash()))
|
||||||
|
|
||||||
edges = spec.edges_from_dependents()
|
edges = spec.edges_from_dependents()
|
||||||
@ -1789,7 +1825,7 @@ def _spec_clauses(
|
|||||||
return clauses
|
return clauses
|
||||||
|
|
||||||
def define_package_versions_and_validate_preferences(
|
def define_package_versions_and_validate_preferences(
|
||||||
self, possible_pkgs, *, require_checksum: bool, allow_deprecated: bool
|
self, possible_pkgs: Set[str], *, require_checksum: bool, allow_deprecated: bool
|
||||||
):
|
):
|
||||||
"""Declare any versions in specs not declared in packages."""
|
"""Declare any versions in specs not declared in packages."""
|
||||||
packages_yaml = spack.config.get("packages")
|
packages_yaml = spack.config.get("packages")
|
||||||
@ -1822,7 +1858,7 @@ def define_package_versions_and_validate_preferences(
|
|||||||
if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]:
|
if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
version_defs = []
|
version_defs: List[GitOrStandardVersion] = []
|
||||||
|
|
||||||
for vstr in packages_yaml[pkg_name]["version"]:
|
for vstr in packages_yaml[pkg_name]["version"]:
|
||||||
v = vn.ver(vstr)
|
v = vn.ver(vstr)
|
||||||
@ -2033,13 +2069,6 @@ def target_defaults(self, specs):
|
|||||||
|
|
||||||
def virtual_providers(self):
|
def virtual_providers(self):
|
||||||
self.gen.h2("Virtual providers")
|
self.gen.h2("Virtual providers")
|
||||||
msg = (
|
|
||||||
"Internal Error: possible_virtuals is not populated. Please report to the spack"
|
|
||||||
" maintainers"
|
|
||||||
)
|
|
||||||
assert self.possible_virtuals is not None, msg
|
|
||||||
|
|
||||||
# what provides what
|
|
||||||
for vspec in sorted(self.possible_virtuals):
|
for vspec in sorted(self.possible_virtuals):
|
||||||
self.gen.fact(fn.virtual(vspec))
|
self.gen.fact(fn.virtual(vspec))
|
||||||
self.gen.newline()
|
self.gen.newline()
|
||||||
@ -2236,7 +2265,7 @@ def define_concrete_input_specs(self, specs, possible):
|
|||||||
|
|
||||||
def setup(
|
def setup(
|
||||||
self,
|
self,
|
||||||
specs: Sequence[spack.spec.Spec],
|
specs: List[spack.spec.Spec],
|
||||||
*,
|
*,
|
||||||
reuse: Optional[List[spack.spec.Spec]] = None,
|
reuse: Optional[List[spack.spec.Spec]] = None,
|
||||||
allow_deprecated: bool = False,
|
allow_deprecated: bool = False,
|
||||||
|
Loading…
Reference in New Issue
Block a user