Compare commits

...

1 Commits

Author SHA1 Message Date
Todd Gamblin
39c074ff79
solver: refactor transforms in condition generation
- [x] allow caller of `condition()` to pass lists of transforms
- [x] all transform functions now take trigger *and* effect as parameters
- [x] add some utility functions to simplify `condition()`
2023-12-14 01:40:47 -08:00

View File

@ -13,7 +13,7 @@
import re import re
import types import types
import warnings import warnings
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
import archspec.cpu import archspec.cpu
@ -338,12 +338,36 @@ def __getattr__(self, name):
fn = AspFunctionBuilder() fn = AspFunctionBuilder()
TransformFunction = Callable[[spack.spec.Spec, List[AspFunction]], List[AspFunction]] TransformFunction = Callable[
[spack.spec.Spec, spack.spec.Spec, List[AspFunction]], List[AspFunction]
]
def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunction]: def transform(
required: spack.spec.Spec,
imposed: spack.spec.Spec,
clauses: List[AspFunction],
transformations: Optional[List[TransformFunction]],
) -> List[AspFunction]:
"""Apply a list of TransformFunctions in order."""
if transformations is None:
return clauses
for func in transformations:
clauses = func(required, imposed, clauses)
return clauses
def cond_key(spec, transforms):
"""Key generator for caching triggers and effects"""
return (str(spec), None) if transforms is None else (str(spec), tuple(transforms))
def remove_node(
required: spack.spec.Spec, imposed: spack.spec.Spec, functions: List[AspFunction]
) -> List[AspFunction]:
"""Transformation that removes all "node" and "virtual_node" from the input list of facts.""" """Transformation that removes all "node" and "virtual_node" from the input list of facts."""
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts)) return [func for func in functions if func.args[0] not in ("node", "virtual_node")]
def _create_counter(specs, tests): def _create_counter(specs, tests):
@ -1501,14 +1525,26 @@ def variant_rules(self, pkg):
self.gen.newline() self.gen.newline()
def _lookup_condition_id(self, condition, transforms, factory, cache_by_name):
"""Look up or create a condition in a trigger/effect cache."""
key = cond_key(condition, transforms)
cache = cache_by_name[condition.name]
pair = cache.get(key)
if pair is None:
pair = cache[key] = (next(self._id_counter), factory())
id, _ = pair
return id
def condition( def condition(
self, self,
required_spec: spack.spec.Spec, required_spec: spack.spec.Spec,
imposed_spec: Optional[spack.spec.Spec] = None, imposed_spec: Optional[spack.spec.Spec] = None,
name: Optional[str] = None, name: Optional[str] = None,
msg: Optional[str] = None, msg: Optional[str] = None,
transform_required: Optional[TransformFunction] = None, transform_required: Optional[List[TransformFunction]] = None,
transform_imposed: Optional[TransformFunction] = remove_node, transform_imposed: Optional[List[TransformFunction]] = [remove_node],
): ):
"""Generate facts for a dependency or virtual provider condition. """Generate facts for a dependency or virtual provider condition.
@ -1536,35 +1572,27 @@ def condition(
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg)) self.gen.fact(fn.condition_reason(condition_id, msg))
cache = self._trigger_cache[named_cond.name] def make_requirements():
named_cond_key = (str(named_cond), transform_required)
if named_cond_key not in cache:
trigger_id = next(self._id_counter)
requirements = self.spec_clauses(named_cond, body=True, required_from=name) requirements = self.spec_clauses(named_cond, body=True, required_from=name)
return transform(named_cond, imposed_spec, requirements, transform_required)
if transform_required: trigger_id = self._lookup_condition_id(
requirements = transform_required(named_cond, requirements) named_cond, transform_required, make_requirements, self._trigger_cache
)
cache[named_cond_key] = (trigger_id, requirements)
trigger_id, requirements = cache[named_cond_key]
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
if not imposed_spec: if not imposed_spec:
return condition_id return condition_id
cache = self._effect_cache[named_cond.name] def make_impositions():
imposed_spec_key = (str(imposed_spec), transform_imposed) impositions = self.spec_clauses(imposed_spec, body=False, required_from=name)
if imposed_spec_key not in cache: return transform(named_cond, imposed_spec, impositions, transform_imposed)
effect_id = next(self._id_counter)
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
if transform_imposed: effect_id = self._lookup_condition_id(
requirements = transform_imposed(imposed_spec, requirements) imposed_spec, transform_imposed, make_impositions, self._effect_cache
)
cache[imposed_spec_key] = (effect_id, requirements)
effect_id, requirements = cache[imposed_spec_key]
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
return condition_id return condition_id
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False): def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
@ -1627,14 +1655,12 @@ def package_dependencies_rules(self, pkg):
else: else:
pass pass
def track_dependencies(input_spec, requirements): def track_dependencies(required, imposed, requirements):
return requirements + [fn.attr("track_dependencies", input_spec.name)] return requirements + [fn.attr("track_dependencies", required.name)]
def dependency_holds(input_spec, requirements): def dependency_holds(required, imposed, impositions):
return remove_node(input_spec, requirements) + [ return impositions + [
fn.attr( fn.attr("dependency_holds", pkg.name, imposed.name, dt.flag_to_string(t))
"dependency_holds", pkg.name, input_spec.name, dt.flag_to_string(t)
)
for t in dt.ALL_FLAGS for t in dt.ALL_FLAGS
if t & depflag if t & depflag
] ]
@ -1644,8 +1670,8 @@ def dependency_holds(input_spec, requirements):
dep.spec, dep.spec,
name=pkg.name, name=pkg.name,
msg=msg, msg=msg,
transform_required=track_dependencies, transform_required=[track_dependencies],
transform_imposed=dependency_holds, transform_imposed=[remove_node, dependency_holds],
) )
self.gen.newline() self.gen.newline()
@ -1751,7 +1777,7 @@ def emit_facts_from_requirement_rules(self, rules: List[RequirementRule]):
required_spec=when_spec, required_spec=when_spec,
imposed_spec=spec, imposed_spec=spec,
name=pkg_name, name=pkg_name,
transform_imposed=transform, transform_imposed=[transform],
msg=f"{spec_str} is a requirement for package {pkg_name}", msg=f"{spec_str} is a requirement for package {pkg_name}",
) )
except Exception as e: except Exception as e:
@ -1816,14 +1842,14 @@ def external_packages(self):
for local_idx, spec in enumerate(external_specs): for local_idx, spec in enumerate(external_specs):
msg = "%s available as external when satisfying %s" % (spec.name, spec) msg = "%s available as external when satisfying %s" % (spec.name, spec)
def external_imposition(input_spec, _): def external_imposition(required, imposed, _):
return [fn.attr("external_conditions_hold", input_spec.name, local_idx)] return [fn.attr("external_conditions_hold", imposed.name, local_idx)]
self.condition( self.condition(
spec, spec,
spack.spec.Spec(spec.name), spack.spec.Spec(spec.name),
msg=msg, msg=msg,
transform_imposed=external_imposition, transform_imposed=[external_imposition],
) )
self.possible_versions[spec.name].add(spec.version) self.possible_versions[spec.name].add(spec.version)
self.gen.newline() self.gen.newline()