diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py index 4d9b5a0c393..68a80cfd19f 100644 --- a/lib/spack/spack/solver/asp.py +++ b/lib/spack/spack/solver/asp.py @@ -13,7 +13,7 @@ import re import types import warnings -from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union import archspec.cpu @@ -338,12 +338,36 @@ def __getattr__(self, name): fn = AspFunctionBuilder() -TransformFunction = Callable[[spack.spec.Spec, List[AspFunction]], List[AspFunction]] +TransformFunction = Callable[ + [spack.spec.Spec, spack.spec.Spec, List[AspFunction]], List[AspFunction] +] -def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunction]: +def transform( + required: spack.spec.Spec, + imposed: spack.spec.Spec, + clauses: List[AspFunction], + transformations: Optional[List[TransformFunction]], +) -> List[AspFunction]: + """Apply a list of TransformFunctions in order.""" + if transformations is None: + return clauses + + for func in transformations: + clauses = func(required, imposed, clauses) + return clauses + + +def cond_key(spec, transforms): + """Key generator for caching triggers and effects""" + return (str(spec), None) if transforms is None else (str(spec), tuple(transforms)) + + +def remove_node( + required: spack.spec.Spec, imposed: spack.spec.Spec, functions: List[AspFunction] +) -> List[AspFunction]: """Transformation that removes all "node" and "virtual_node" from the input list of facts.""" - return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts)) + return [func for func in functions if func.args[0] not in ("node", "virtual_node")] def _create_counter(specs, tests): @@ -1501,14 +1525,26 @@ def variant_rules(self, pkg): self.gen.newline() + def _lookup_condition_id(self, condition, transforms, factory, cache_by_name): + """Look up or create a condition in a trigger/effect cache.""" + key = cond_key(condition, transforms) + cache = cache_by_name[condition.name] + + pair = cache.get(key) + if pair is None: + pair = cache[key] = (next(self._id_counter), factory()) + + id, _ = pair + return id + def condition( self, required_spec: spack.spec.Spec, imposed_spec: Optional[spack.spec.Spec] = None, name: Optional[str] = None, msg: Optional[str] = None, - transform_required: Optional[TransformFunction] = None, - transform_imposed: Optional[TransformFunction] = remove_node, + transform_required: Optional[List[TransformFunction]] = None, + transform_imposed: Optional[List[TransformFunction]] = [remove_node], ): """Generate facts for a dependency or virtual provider condition. @@ -1536,35 +1572,27 @@ def condition( self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id))) self.gen.fact(fn.condition_reason(condition_id, msg)) - cache = self._trigger_cache[named_cond.name] - - named_cond_key = (str(named_cond), transform_required) - if named_cond_key not in cache: - trigger_id = next(self._id_counter) + def make_requirements(): requirements = self.spec_clauses(named_cond, body=True, required_from=name) + return transform(named_cond, imposed_spec, requirements, transform_required) - if transform_required: - requirements = transform_required(named_cond, requirements) - - cache[named_cond_key] = (trigger_id, requirements) - trigger_id, requirements = cache[named_cond_key] + trigger_id = self._lookup_condition_id( + named_cond, transform_required, make_requirements, self._trigger_cache + ) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id))) if not imposed_spec: return condition_id - cache = self._effect_cache[named_cond.name] - imposed_spec_key = (str(imposed_spec), transform_imposed) - if imposed_spec_key not in cache: - effect_id = next(self._id_counter) - requirements = self.spec_clauses(imposed_spec, body=False, required_from=name) + def make_impositions(): + impositions = self.spec_clauses(imposed_spec, body=False, required_from=name) + return transform(named_cond, imposed_spec, impositions, transform_imposed) - if transform_imposed: - requirements = transform_imposed(imposed_spec, requirements) - - cache[imposed_spec_key] = (effect_id, requirements) - effect_id, requirements = cache[imposed_spec_key] + effect_id = self._lookup_condition_id( + imposed_spec, transform_imposed, make_impositions, self._effect_cache + ) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id))) + return condition_id def impose(self, condition_id, imposed_spec, node=True, name=None, body=False): @@ -1627,14 +1655,12 @@ def package_dependencies_rules(self, pkg): else: pass - def track_dependencies(input_spec, requirements): - return requirements + [fn.attr("track_dependencies", input_spec.name)] + def track_dependencies(required, imposed, requirements): + return requirements + [fn.attr("track_dependencies", required.name)] - def dependency_holds(input_spec, requirements): - return remove_node(input_spec, requirements) + [ - fn.attr( - "dependency_holds", pkg.name, input_spec.name, dt.flag_to_string(t) - ) + def dependency_holds(required, imposed, impositions): + return impositions + [ + fn.attr("dependency_holds", pkg.name, imposed.name, dt.flag_to_string(t)) for t in dt.ALL_FLAGS if t & depflag ] @@ -1644,8 +1670,8 @@ def dependency_holds(input_spec, requirements): dep.spec, name=pkg.name, msg=msg, - transform_required=track_dependencies, - transform_imposed=dependency_holds, + transform_required=[track_dependencies], + transform_imposed=[remove_node, dependency_holds], ) self.gen.newline() @@ -1743,15 +1769,13 @@ def emit_facts_from_requirement_rules(self, rules: List[RequirementRule]): try: # With virtual we want to emit "node" and "virtual_node" in imposed specs - transform: Optional[TransformFunction] = remove_node - if virtual: - transform = None + transform = None if virtual else [remove_node] member_id = self.condition( required_spec=when_spec, imposed_spec=spec, name=pkg_name, - transform_imposed=transform, + transform_imposed=[transform], msg=f"{spec_str} is a requirement for package {pkg_name}", ) except Exception as e: @@ -1816,14 +1840,14 @@ def external_packages(self): for local_idx, spec in enumerate(external_specs): msg = "%s available as external when satisfying %s" % (spec.name, spec) - def external_imposition(input_spec, _): - return [fn.attr("external_conditions_hold", input_spec.name, local_idx)] + def external_imposition(required, imposed, _): + return [fn.attr("external_conditions_hold", imposed.name, local_idx)] self.condition( spec, spack.spec.Spec(spec.name), msg=msg, - transform_imposed=external_imposition, + transform_imposed=[external_imposition], ) self.possible_versions[spec.name].add(spec.version) self.gen.newline()