refactor: add type annotations and refactor solver conditions (#42081)
Refactoring `SpackSolverSetup` is a bit easier with type annotations, so I started adding some. This adds annotations for the (many) instance variables on `SpackSolverSetup` as well as a few other places. This also refactors `condition()` to reduce redundancy and to allow `_get_condition_id()` to be called independently of the larger condition function. Co-authored-by: Massimiliano Culpo <massimiliano.culpo@gmail.com>
This commit is contained in:
		@@ -566,6 +566,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
 | 
			
		||||
    provided: Dict["spack.spec.Spec", Set["spack.spec.Spec"]]
 | 
			
		||||
    provided_together: Dict["spack.spec.Spec", List[Set[str]]]
 | 
			
		||||
    patches: Dict["spack.spec.Spec", List["spack.patch.Patch"]]
 | 
			
		||||
    variants: Dict[str, Tuple["spack.variant.Variant", "spack.spec.Spec"]]
 | 
			
		||||
 | 
			
		||||
    #: By default, packages are not virtual
 | 
			
		||||
    #: Virtual packages override this attribute
 | 
			
		||||
 
 | 
			
		||||
@@ -15,7 +15,7 @@
 | 
			
		||||
import types
 | 
			
		||||
import typing
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
 | 
			
		||||
from typing import Callable, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Type, Union
 | 
			
		||||
 | 
			
		||||
import archspec.cpu
 | 
			
		||||
 | 
			
		||||
@@ -258,7 +258,7 @@ def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunc
 | 
			
		||||
    return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _create_counter(specs, tests):
 | 
			
		||||
def _create_counter(specs: List[spack.spec.Spec], tests: bool):
 | 
			
		||||
    strategy = spack.config.CONFIG.get("concretizer:duplicates:strategy", "none")
 | 
			
		||||
    if strategy == "full":
 | 
			
		||||
        return FullDuplicatesCounter(specs, tests=tests)
 | 
			
		||||
@@ -897,35 +897,41 @@ def __iter__(self):
 | 
			
		||||
        return iter(self.data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# types for condition caching in solver setup
 | 
			
		||||
ConditionSpecKey = Tuple[str, Optional[TransformFunction]]
 | 
			
		||||
ConditionIdFunctionPair = Tuple[int, List[AspFunction]]
 | 
			
		||||
ConditionSpecCache = Dict[str, Dict[ConditionSpecKey, ConditionIdFunctionPair]]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SpackSolverSetup:
 | 
			
		||||
    """Class to set up and run a Spack concretization solve."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, tests=False):
 | 
			
		||||
        self.gen = None  # set by setup()
 | 
			
		||||
    def __init__(self, tests: bool = False):
 | 
			
		||||
        # these are all initialized in setup()
 | 
			
		||||
        self.gen: "ProblemInstanceBuilder" = ProblemInstanceBuilder()
 | 
			
		||||
        self.possible_virtuals: Set[str] = set()
 | 
			
		||||
 | 
			
		||||
        self.assumptions = []
 | 
			
		||||
        self.declared_versions = collections.defaultdict(list)
 | 
			
		||||
        self.possible_versions = collections.defaultdict(set)
 | 
			
		||||
        self.deprecated_versions = collections.defaultdict(set)
 | 
			
		||||
        self.assumptions: List[Tuple["clingo.Symbol", bool]] = []  # type: ignore[name-defined]
 | 
			
		||||
        self.declared_versions: Dict[str, List[DeclaredVersion]] = collections.defaultdict(list)
 | 
			
		||||
        self.possible_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(set)
 | 
			
		||||
        self.deprecated_versions: Dict[str, Set[GitOrStandardVersion]] = collections.defaultdict(
 | 
			
		||||
            set
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.possible_virtuals = None
 | 
			
		||||
        self.possible_compilers = []
 | 
			
		||||
        self.possible_oses = set()
 | 
			
		||||
        self.variant_values_from_specs = set()
 | 
			
		||||
        self.version_constraints = set()
 | 
			
		||||
        self.target_constraints = set()
 | 
			
		||||
        self.default_targets = []
 | 
			
		||||
        self.compiler_version_constraints = set()
 | 
			
		||||
        self.post_facts = []
 | 
			
		||||
        self.possible_compilers: List = []
 | 
			
		||||
        self.possible_oses: Set = set()
 | 
			
		||||
        self.variant_values_from_specs: Set = set()
 | 
			
		||||
        self.version_constraints: Set = set()
 | 
			
		||||
        self.target_constraints: Set = set()
 | 
			
		||||
        self.default_targets: List = []
 | 
			
		||||
        self.compiler_version_constraints: Set = set()
 | 
			
		||||
        self.post_facts: List = []
 | 
			
		||||
 | 
			
		||||
        # (ID, CompilerSpec) -> dictionary of attributes
 | 
			
		||||
        self.compiler_info = collections.defaultdict(dict)
 | 
			
		||||
        self.reusable_and_possible: ConcreteSpecsByHash = ConcreteSpecsByHash()
 | 
			
		||||
 | 
			
		||||
        self.reusable_and_possible = ConcreteSpecsByHash()
 | 
			
		||||
 | 
			
		||||
        self._id_counter = itertools.count()
 | 
			
		||||
        self._trigger_cache = collections.defaultdict(dict)
 | 
			
		||||
        self._effect_cache = collections.defaultdict(dict)
 | 
			
		||||
        self._id_counter: Iterator[int] = itertools.count()
 | 
			
		||||
        self._trigger_cache: ConditionSpecCache = collections.defaultdict(dict)
 | 
			
		||||
        self._effect_cache: ConditionSpecCache = collections.defaultdict(dict)
 | 
			
		||||
 | 
			
		||||
        # Caches to optimize the setup phase of the solver
 | 
			
		||||
        self.target_specs_cache = None
 | 
			
		||||
@@ -937,8 +943,8 @@ def __init__(self, tests=False):
 | 
			
		||||
        self.concretize_everything = True
 | 
			
		||||
 | 
			
		||||
        # Set during the call to setup
 | 
			
		||||
        self.pkgs = None
 | 
			
		||||
        self.explicitly_required_namespaces = {}
 | 
			
		||||
        self.pkgs: Set[str] = set()
 | 
			
		||||
        self.explicitly_required_namespaces: Dict[str, str] = {}
 | 
			
		||||
 | 
			
		||||
    def pkg_version_rules(self, pkg):
 | 
			
		||||
        """Output declared versions of a package.
 | 
			
		||||
@@ -1222,6 +1228,38 @@ def variant_rules(self, pkg):
 | 
			
		||||
 | 
			
		||||
            self.gen.newline()
 | 
			
		||||
 | 
			
		||||
    def _get_condition_id(
 | 
			
		||||
        self,
 | 
			
		||||
        named_cond: spack.spec.Spec,
 | 
			
		||||
        cache: ConditionSpecCache,
 | 
			
		||||
        body: bool,
 | 
			
		||||
        transform: Optional[TransformFunction] = None,
 | 
			
		||||
    ) -> int:
 | 
			
		||||
        """Get the id for one half of a condition (either a trigger or an imposed constraint).
 | 
			
		||||
 | 
			
		||||
        Construct a key from the condition spec and any associated transformation, and
 | 
			
		||||
        cache the ASP functions that they imply. The saved functions will be output
 | 
			
		||||
        later in ``trigger_rules()`` and ``effect_rules()``.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The id of the cached trigger or effect.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        pkg_cache = cache[named_cond.name]
 | 
			
		||||
 | 
			
		||||
        named_cond_key = (str(named_cond), transform)
 | 
			
		||||
        result = pkg_cache.get(named_cond_key)
 | 
			
		||||
        if result:
 | 
			
		||||
            return result[0]
 | 
			
		||||
 | 
			
		||||
        cond_id = next(self._id_counter)
 | 
			
		||||
        requirements = self.spec_clauses(named_cond, body=body)
 | 
			
		||||
        if transform:
 | 
			
		||||
            requirements = transform(named_cond, requirements)
 | 
			
		||||
        pkg_cache[named_cond_key] = (cond_id, requirements)
 | 
			
		||||
 | 
			
		||||
        return cond_id
 | 
			
		||||
 | 
			
		||||
    def condition(
 | 
			
		||||
        self,
 | 
			
		||||
        required_spec: spack.spec.Spec,
 | 
			
		||||
@@ -1247,7 +1285,8 @@ def condition(
 | 
			
		||||
        """
 | 
			
		||||
        named_cond = required_spec.copy()
 | 
			
		||||
        named_cond.name = named_cond.name or name
 | 
			
		||||
        assert named_cond.name, "must provide name for anonymous conditions!"
 | 
			
		||||
        if not named_cond.name:
 | 
			
		||||
            raise ValueError(f"Must provide a name for anonymous condition: '{named_cond}'")
 | 
			
		||||
 | 
			
		||||
        # Check if we can emit the requirements before updating the condition ID counter.
 | 
			
		||||
        # In this way, if a condition can't be emitted but the exception is handled in the caller,
 | 
			
		||||
@@ -1257,35 +1296,19 @@ def condition(
 | 
			
		||||
        self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
 | 
			
		||||
        self.gen.fact(fn.condition_reason(condition_id, msg))
 | 
			
		||||
 | 
			
		||||
        cache = self._trigger_cache[named_cond.name]
 | 
			
		||||
 | 
			
		||||
        named_cond_key = (str(named_cond), transform_required)
 | 
			
		||||
        if named_cond_key not in cache:
 | 
			
		||||
            trigger_id = next(self._id_counter)
 | 
			
		||||
            requirements = self.spec_clauses(named_cond, body=True, required_from=name)
 | 
			
		||||
 | 
			
		||||
            if transform_required:
 | 
			
		||||
                requirements = transform_required(named_cond, requirements)
 | 
			
		||||
 | 
			
		||||
            cache[named_cond_key] = (trigger_id, requirements)
 | 
			
		||||
        trigger_id, requirements = cache[named_cond_key]
 | 
			
		||||
        trigger_id = self._get_condition_id(
 | 
			
		||||
            named_cond, cache=self._trigger_cache, body=True, transform=transform_required
 | 
			
		||||
        )
 | 
			
		||||
        self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
 | 
			
		||||
 | 
			
		||||
        if not imposed_spec:
 | 
			
		||||
            return condition_id
 | 
			
		||||
 | 
			
		||||
        cache = self._effect_cache[named_cond.name]
 | 
			
		||||
        imposed_spec_key = (str(imposed_spec), transform_imposed)
 | 
			
		||||
        if imposed_spec_key not in cache:
 | 
			
		||||
            effect_id = next(self._id_counter)
 | 
			
		||||
            requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
 | 
			
		||||
 | 
			
		||||
            if transform_imposed:
 | 
			
		||||
                requirements = transform_imposed(imposed_spec, requirements)
 | 
			
		||||
 | 
			
		||||
            cache[imposed_spec_key] = (effect_id, requirements)
 | 
			
		||||
        effect_id, requirements = cache[imposed_spec_key]
 | 
			
		||||
        effect_id = self._get_condition_id(
 | 
			
		||||
            imposed_spec, cache=self._effect_cache, body=False, transform=transform_imposed
 | 
			
		||||
        )
 | 
			
		||||
        self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
 | 
			
		||||
 | 
			
		||||
        return condition_id
 | 
			
		||||
 | 
			
		||||
    def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
 | 
			
		||||
@@ -1387,23 +1410,13 @@ def virtual_preferences(self, pkg_name, func):
 | 
			
		||||
 | 
			
		||||
    def provider_defaults(self):
 | 
			
		||||
        self.gen.h2("Default virtual providers")
 | 
			
		||||
        msg = (
 | 
			
		||||
            "Internal Error: possible_virtuals is not populated. Please report to the spack"
 | 
			
		||||
            " maintainers"
 | 
			
		||||
        )
 | 
			
		||||
        assert self.possible_virtuals is not None, msg
 | 
			
		||||
        self.virtual_preferences(
 | 
			
		||||
            "all", lambda v, p, i: self.gen.fact(fn.default_provider_preference(v, p, i))
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def provider_requirements(self):
 | 
			
		||||
        self.gen.h2("Requirements on virtual providers")
 | 
			
		||||
        msg = (
 | 
			
		||||
            "Internal Error: possible_virtuals is not populated. Please report to the spack"
 | 
			
		||||
            " maintainers"
 | 
			
		||||
        )
 | 
			
		||||
        parser = RequirementParser(spack.config.CONFIG)
 | 
			
		||||
        assert self.possible_virtuals is not None, msg
 | 
			
		||||
        for virtual_str in sorted(self.possible_virtuals):
 | 
			
		||||
            rules = parser.rules_from_virtual(virtual_str)
 | 
			
		||||
            if rules:
 | 
			
		||||
@@ -1602,35 +1615,57 @@ def flag_defaults(self):
 | 
			
		||||
                        fn.compiler_version_flag(compiler.name, compiler.version, name, flag)
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    def spec_clauses(self, *args, **kwargs):
 | 
			
		||||
        """Wrap a call to `_spec_clauses()` into a try/except block that
 | 
			
		||||
        raises a comprehensible error message in case of failure.
 | 
			
		||||
    def spec_clauses(
 | 
			
		||||
        self,
 | 
			
		||||
        spec: spack.spec.Spec,
 | 
			
		||||
        *,
 | 
			
		||||
        body: bool = False,
 | 
			
		||||
        transitive: bool = True,
 | 
			
		||||
        expand_hashes: bool = False,
 | 
			
		||||
        concrete_build_deps=False,
 | 
			
		||||
        required_from: Optional[str] = None,
 | 
			
		||||
    ) -> List[AspFunction]:
 | 
			
		||||
        """Wrap a call to `_spec_clauses()` into a try/except block with better error handling.
 | 
			
		||||
 | 
			
		||||
        Arguments are as for ``_spec_clauses()`` except ``required_from``.
 | 
			
		||||
 | 
			
		||||
        Arguments:
 | 
			
		||||
            required_from: name of package that caused this call.
 | 
			
		||||
        """
 | 
			
		||||
        requestor = kwargs.pop("required_from", None)
 | 
			
		||||
        try:
 | 
			
		||||
            clauses = self._spec_clauses(*args, **kwargs)
 | 
			
		||||
            clauses = self._spec_clauses(
 | 
			
		||||
                spec,
 | 
			
		||||
                body=body,
 | 
			
		||||
                transitive=transitive,
 | 
			
		||||
                expand_hashes=expand_hashes,
 | 
			
		||||
                concrete_build_deps=concrete_build_deps,
 | 
			
		||||
            )
 | 
			
		||||
        except RuntimeError as exc:
 | 
			
		||||
            msg = str(exc)
 | 
			
		||||
            if requestor:
 | 
			
		||||
                msg += ' [required from package "{0}"]'.format(requestor)
 | 
			
		||||
            if required_from:
 | 
			
		||||
                msg += f" [required from package '{required_from}']"
 | 
			
		||||
            raise RuntimeError(msg)
 | 
			
		||||
        return clauses
 | 
			
		||||
 | 
			
		||||
    def _spec_clauses(
 | 
			
		||||
        self, spec, body=False, transitive=True, expand_hashes=False, concrete_build_deps=False
 | 
			
		||||
    ):
 | 
			
		||||
        self,
 | 
			
		||||
        spec: spack.spec.Spec,
 | 
			
		||||
        *,
 | 
			
		||||
        body: bool = False,
 | 
			
		||||
        transitive: bool = True,
 | 
			
		||||
        expand_hashes: bool = False,
 | 
			
		||||
        concrete_build_deps: bool = False,
 | 
			
		||||
    ) -> List[AspFunction]:
 | 
			
		||||
        """Return a list of clauses for a spec mandates are true.
 | 
			
		||||
 | 
			
		||||
        Arguments:
 | 
			
		||||
            spec (spack.spec.Spec): the spec to analyze
 | 
			
		||||
            body (bool): if True, generate clauses to be used in rule bodies
 | 
			
		||||
                (final values) instead of rule heads (setters).
 | 
			
		||||
            transitive (bool): if False, don't generate clauses from
 | 
			
		||||
                dependencies (default True)
 | 
			
		||||
            expand_hashes (bool): if True, descend into hashes of concrete specs
 | 
			
		||||
                (default False)
 | 
			
		||||
            concrete_build_deps (bool): if False, do not include pure build deps
 | 
			
		||||
                of concrete specs (as they have no effect on runtime constraints)
 | 
			
		||||
            spec: the spec to analyze
 | 
			
		||||
            body: if True, generate clauses to be used in rule bodies (final values) instead
 | 
			
		||||
                of rule heads (setters).
 | 
			
		||||
            transitive: if False, don't generate clauses from dependencies (default True)
 | 
			
		||||
            expand_hashes: if True, descend into hashes of concrete specs (default False)
 | 
			
		||||
            concrete_build_deps: if False, do not include pure build deps of concrete specs
 | 
			
		||||
                (as they have no effect on runtime constraints)
 | 
			
		||||
 | 
			
		||||
        Normally, if called with ``transitive=True``, ``spec_clauses()`` just generates
 | 
			
		||||
        hashes for the dependency requirements of concrete specs. If ``expand_hashes``
 | 
			
		||||
@@ -1640,7 +1675,7 @@ def _spec_clauses(
 | 
			
		||||
        """
 | 
			
		||||
        clauses = []
 | 
			
		||||
 | 
			
		||||
        f = _Body if body else _Head
 | 
			
		||||
        f: Union[Type[_Head], Type[_Body]] = _Body if body else _Head
 | 
			
		||||
 | 
			
		||||
        if spec.name:
 | 
			
		||||
            clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
 | 
			
		||||
@@ -1729,8 +1764,9 @@ def _spec_clauses(
 | 
			
		||||
        # dependencies
 | 
			
		||||
        if spec.concrete:
 | 
			
		||||
            # older specs do not have package hashes, so we have to do this carefully
 | 
			
		||||
            if getattr(spec, "_package_hash", None):
 | 
			
		||||
                clauses.append(fn.attr("package_hash", spec.name, spec._package_hash))
 | 
			
		||||
            package_hash = getattr(spec, "_package_hash", None)
 | 
			
		||||
            if package_hash:
 | 
			
		||||
                clauses.append(fn.attr("package_hash", spec.name, package_hash))
 | 
			
		||||
            clauses.append(fn.attr("hash", spec.name, spec.dag_hash()))
 | 
			
		||||
 | 
			
		||||
        edges = spec.edges_from_dependents()
 | 
			
		||||
@@ -1789,7 +1825,7 @@ def _spec_clauses(
 | 
			
		||||
        return clauses
 | 
			
		||||
 | 
			
		||||
    def define_package_versions_and_validate_preferences(
 | 
			
		||||
        self, possible_pkgs, *, require_checksum: bool, allow_deprecated: bool
 | 
			
		||||
        self, possible_pkgs: Set[str], *, require_checksum: bool, allow_deprecated: bool
 | 
			
		||||
    ):
 | 
			
		||||
        """Declare any versions in specs not declared in packages."""
 | 
			
		||||
        packages_yaml = spack.config.get("packages")
 | 
			
		||||
@@ -1822,7 +1858,7 @@ def define_package_versions_and_validate_preferences(
 | 
			
		||||
            if pkg_name not in packages_yaml or "version" not in packages_yaml[pkg_name]:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            version_defs = []
 | 
			
		||||
            version_defs: List[GitOrStandardVersion] = []
 | 
			
		||||
 | 
			
		||||
            for vstr in packages_yaml[pkg_name]["version"]:
 | 
			
		||||
                v = vn.ver(vstr)
 | 
			
		||||
@@ -2033,13 +2069,6 @@ def target_defaults(self, specs):
 | 
			
		||||
 | 
			
		||||
    def virtual_providers(self):
 | 
			
		||||
        self.gen.h2("Virtual providers")
 | 
			
		||||
        msg = (
 | 
			
		||||
            "Internal Error: possible_virtuals is not populated. Please report to the spack"
 | 
			
		||||
            " maintainers"
 | 
			
		||||
        )
 | 
			
		||||
        assert self.possible_virtuals is not None, msg
 | 
			
		||||
 | 
			
		||||
        # what provides what
 | 
			
		||||
        for vspec in sorted(self.possible_virtuals):
 | 
			
		||||
            self.gen.fact(fn.virtual(vspec))
 | 
			
		||||
        self.gen.newline()
 | 
			
		||||
@@ -2236,7 +2265,7 @@ def define_concrete_input_specs(self, specs, possible):
 | 
			
		||||
 | 
			
		||||
    def setup(
 | 
			
		||||
        self,
 | 
			
		||||
        specs: Sequence[spack.spec.Spec],
 | 
			
		||||
        specs: List[spack.spec.Spec],
 | 
			
		||||
        *,
 | 
			
		||||
        reuse: Optional[List[spack.spec.Spec]] = None,
 | 
			
		||||
        allow_deprecated: bool = False,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user