Compare commits

...

1 Commits

Author SHA1 Message Date
Todd Gamblin
39c074ff79
solver: refactor transforms in condition generation
- [x] allow caller of `condition()` to pass lists of transforms
- [x] all transform functions now take trigger *and* effect as parameters
- [x] add some utility functions to simplify `condition()`
2023-12-14 01:40:47 -08:00

View File

@ -13,7 +13,7 @@
import re
import types
import warnings
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
import archspec.cpu
@ -338,12 +338,36 @@ def __getattr__(self, name):
fn = AspFunctionBuilder()
TransformFunction = Callable[[spack.spec.Spec, List[AspFunction]], List[AspFunction]]
TransformFunction = Callable[
[spack.spec.Spec, spack.spec.Spec, List[AspFunction]], List[AspFunction]
]
def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunction]:
def transform(
required: spack.spec.Spec,
imposed: spack.spec.Spec,
clauses: List[AspFunction],
transformations: Optional[List[TransformFunction]],
) -> List[AspFunction]:
"""Apply a list of TransformFunctions in order."""
if transformations is None:
return clauses
for func in transformations:
clauses = func(required, imposed, clauses)
return clauses
def cond_key(spec, transforms):
"""Key generator for caching triggers and effects"""
return (str(spec), None) if transforms is None else (str(spec), tuple(transforms))
def remove_node(
required: spack.spec.Spec, imposed: spack.spec.Spec, functions: List[AspFunction]
) -> List[AspFunction]:
"""Transformation that removes all "node" and "virtual_node" from the input list of facts."""
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
return [func for func in functions if func.args[0] not in ("node", "virtual_node")]
def _create_counter(specs, tests):
@ -1501,14 +1525,26 @@ def variant_rules(self, pkg):
self.gen.newline()
def _lookup_condition_id(self, condition, transforms, factory, cache_by_name):
"""Look up or create a condition in a trigger/effect cache."""
key = cond_key(condition, transforms)
cache = cache_by_name[condition.name]
pair = cache.get(key)
if pair is None:
pair = cache[key] = (next(self._id_counter), factory())
id, _ = pair
return id
def condition(
self,
required_spec: spack.spec.Spec,
imposed_spec: Optional[spack.spec.Spec] = None,
name: Optional[str] = None,
msg: Optional[str] = None,
transform_required: Optional[TransformFunction] = None,
transform_imposed: Optional[TransformFunction] = remove_node,
transform_required: Optional[List[TransformFunction]] = None,
transform_imposed: Optional[List[TransformFunction]] = [remove_node],
):
"""Generate facts for a dependency or virtual provider condition.
@ -1536,35 +1572,27 @@ def condition(
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg))
cache = self._trigger_cache[named_cond.name]
named_cond_key = (str(named_cond), transform_required)
if named_cond_key not in cache:
trigger_id = next(self._id_counter)
def make_requirements():
requirements = self.spec_clauses(named_cond, body=True, required_from=name)
return transform(named_cond, imposed_spec, requirements, transform_required)
if transform_required:
requirements = transform_required(named_cond, requirements)
cache[named_cond_key] = (trigger_id, requirements)
trigger_id, requirements = cache[named_cond_key]
trigger_id = self._lookup_condition_id(
named_cond, transform_required, make_requirements, self._trigger_cache
)
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
if not imposed_spec:
return condition_id
cache = self._effect_cache[named_cond.name]
imposed_spec_key = (str(imposed_spec), transform_imposed)
if imposed_spec_key not in cache:
effect_id = next(self._id_counter)
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
def make_impositions():
impositions = self.spec_clauses(imposed_spec, body=False, required_from=name)
return transform(named_cond, imposed_spec, impositions, transform_imposed)
if transform_imposed:
requirements = transform_imposed(imposed_spec, requirements)
cache[imposed_spec_key] = (effect_id, requirements)
effect_id, requirements = cache[imposed_spec_key]
effect_id = self._lookup_condition_id(
imposed_spec, transform_imposed, make_impositions, self._effect_cache
)
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
return condition_id
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
@ -1627,14 +1655,12 @@ def package_dependencies_rules(self, pkg):
else:
pass
def track_dependencies(input_spec, requirements):
return requirements + [fn.attr("track_dependencies", input_spec.name)]
def track_dependencies(required, imposed, requirements):
return requirements + [fn.attr("track_dependencies", required.name)]
def dependency_holds(input_spec, requirements):
return remove_node(input_spec, requirements) + [
fn.attr(
"dependency_holds", pkg.name, input_spec.name, dt.flag_to_string(t)
)
def dependency_holds(required, imposed, impositions):
return impositions + [
fn.attr("dependency_holds", pkg.name, imposed.name, dt.flag_to_string(t))
for t in dt.ALL_FLAGS
if t & depflag
]
@ -1644,8 +1670,8 @@ def dependency_holds(input_spec, requirements):
dep.spec,
name=pkg.name,
msg=msg,
transform_required=track_dependencies,
transform_imposed=dependency_holds,
transform_required=[track_dependencies],
transform_imposed=[remove_node, dependency_holds],
)
self.gen.newline()
@ -1751,7 +1777,7 @@ def emit_facts_from_requirement_rules(self, rules: List[RequirementRule]):
required_spec=when_spec,
imposed_spec=spec,
name=pkg_name,
transform_imposed=transform,
transform_imposed=[transform],
msg=f"{spec_str} is a requirement for package {pkg_name}",
)
except Exception as e:
@ -1816,14 +1842,14 @@ def external_packages(self):
for local_idx, spec in enumerate(external_specs):
msg = "%s available as external when satisfying %s" % (spec.name, spec)
def external_imposition(input_spec, _):
return [fn.attr("external_conditions_hold", input_spec.name, local_idx)]
def external_imposition(required, imposed, _):
return [fn.attr("external_conditions_hold", imposed.name, local_idx)]
self.condition(
spec,
spack.spec.Spec(spec.name),
msg=msg,
transform_imposed=external_imposition,
transform_imposed=[external_imposition],
)
self.possible_versions[spec.name].add(spec.version)
self.gen.newline()