solver: refactor transforms in condition generation

- [x] allow caller of `condition()` to pass lists of transforms
- [x] all transform functions now take trigger *and* effect as parameters
- [x] add some utility functions to simplify `condition()`
This commit is contained in:
Todd Gamblin 2023-12-14 00:28:06 -08:00
parent a690b8c27c
commit 977f6ce65c
No known key found for this signature in database
GPG Key ID: C16729F1AACF66C6

View File

@ -13,7 +13,7 @@
import re
import types
import warnings
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
import archspec.cpu
@ -338,12 +338,36 @@ def __getattr__(self, name):
fn = AspFunctionBuilder()
TransformFunction = Callable[[spack.spec.Spec, List[AspFunction]], List[AspFunction]]
TransformFunction = Callable[
[spack.spec.Spec, spack.spec.Spec, List[AspFunction]], List[AspFunction]
]
def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunction]:
def transform(
required: spack.spec.Spec,
imposed: spack.spec.Spec,
clauses: List[AspFunction],
transformations: Optional[List[TransformFunction]],
) -> List[AspFunction]:
"""Apply a list of TransformFunctions in order."""
if transformations is None:
return clauses
for func in transformations:
clauses = func(required, imposed, clauses)
return clauses
def cond_key(spec, transforms):
"""Key generator for caching triggers and effects"""
return (str(spec), None) if transforms is None else (str(spec), tuple(transforms))
def remove_node(
required: spack.spec.Spec, imposed: spack.spec.Spec, functions: List[AspFunction]
) -> List[AspFunction]:
"""Transformation that removes all "node" and "virtual_node" from the input list of facts."""
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
return [func for func in functions if func.args[0] not in ("node", "virtual_node")]
def _create_counter(specs, tests):
@ -1501,14 +1525,26 @@ def variant_rules(self, pkg):
self.gen.newline()
def _lookup_condition_id(self, condition, transforms, factory, cache_by_name):
"""Look up or create a condition in a trigger/effect cache."""
key = cond_key(condition, transforms)
cache = cache_by_name[condition.name]
pair = cache.get(key)
if pair is None:
pair = cache[key] = (next(self._id_counter), factory())
id, _ = pair
return id
def condition(
self,
required_spec: spack.spec.Spec,
imposed_spec: Optional[spack.spec.Spec] = None,
name: Optional[str] = None,
msg: Optional[str] = None,
transform_required: Optional[TransformFunction] = None,
transform_imposed: Optional[TransformFunction] = remove_node,
transform_required: Optional[List[TransformFunction]] = None,
transform_imposed: Optional[List[TransformFunction]] = [remove_node],
):
"""Generate facts for a dependency or virtual provider condition.
@ -1536,35 +1572,27 @@ def condition(
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg))
cache = self._trigger_cache[named_cond.name]
named_cond_key = (str(named_cond), transform_required)
if named_cond_key not in cache:
trigger_id = next(self._id_counter)
def make_requirements():
requirements = self.spec_clauses(named_cond, body=True, required_from=name)
return transform(named_cond, imposed_spec, requirements, transform_required)
if transform_required:
requirements = transform_required(named_cond, requirements)
cache[named_cond_key] = (trigger_id, requirements)
trigger_id, requirements = cache[named_cond_key]
trigger_id = self._lookup_condition_id(
named_cond, transform_required, make_requirements, self._trigger_cache
)
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
if not imposed_spec:
return condition_id
cache = self._effect_cache[named_cond.name]
imposed_spec_key = (str(imposed_spec), transform_imposed)
if imposed_spec_key not in cache:
effect_id = next(self._id_counter)
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
def make_impositions():
impositions = self.spec_clauses(imposed_spec, body=False, required_from=name)
return transform(named_cond, imposed_spec, impositions, transform_imposed)
if transform_imposed:
requirements = transform_imposed(imposed_spec, requirements)
cache[imposed_spec_key] = (effect_id, requirements)
effect_id, requirements = cache[imposed_spec_key]
effect_id = self._lookup_condition_id(
imposed_spec, transform_imposed, make_impositions, self._effect_cache
)
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
return condition_id
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
@ -1627,14 +1655,12 @@ def package_dependencies_rules(self, pkg):
else:
pass
def track_dependencies(input_spec, requirements):
return requirements + [fn.attr("track_dependencies", input_spec.name)]
def track_dependencies(required, imposed, requirements):
return requirements + [fn.attr("track_dependencies", required.name)]
def dependency_holds(input_spec, requirements):
return remove_node(input_spec, requirements) + [
fn.attr(
"dependency_holds", pkg.name, input_spec.name, dt.flag_to_string(t)
)
def dependency_holds(required, imposed, impositions):
return impositions + [
fn.attr("dependency_holds", pkg.name, imposed.name, dt.flag_to_string(t))
for t in dt.ALL_FLAGS
if t & depflag
]
@ -1644,8 +1670,8 @@ def dependency_holds(input_spec, requirements):
dep.spec,
name=pkg.name,
msg=msg,
transform_required=track_dependencies,
transform_imposed=dependency_holds,
transform_required=[track_dependencies],
transform_imposed=[remove_node, dependency_holds],
)
self.gen.newline()
@ -1743,15 +1769,13 @@ def emit_facts_from_requirement_rules(self, rules: List[RequirementRule]):
try:
# With virtual we want to emit "node" and "virtual_node" in imposed specs
transform: Optional[TransformFunction] = remove_node
if virtual:
transform = None
transform = None if virtual else [remove_node]
member_id = self.condition(
required_spec=when_spec,
imposed_spec=spec,
name=pkg_name,
transform_imposed=transform,
transform_imposed=[transform],
msg=f"{spec_str} is a requirement for package {pkg_name}",
)
except Exception as e:
@ -1816,14 +1840,14 @@ def external_packages(self):
for local_idx, spec in enumerate(external_specs):
msg = "%s available as external when satisfying %s" % (spec.name, spec)
def external_imposition(input_spec, _):
return [fn.attr("external_conditions_hold", input_spec.name, local_idx)]
def external_imposition(required, imposed, _):
return [fn.attr("external_conditions_hold", imposed.name, local_idx)]
self.condition(
spec,
spack.spec.Spec(spec.name),
msg=msg,
transform_imposed=external_imposition,
transform_imposed=[external_imposition],
)
self.possible_versions[spec.name].add(spec.version)
self.gen.newline()