Compare commits
1 Commits
develop
...
refactor-s
Author | SHA1 | Date | |
---|---|---|---|
![]() |
39c074ff79 |
@ -13,7 +13,7 @@
|
||||
import re
|
||||
import types
|
||||
import warnings
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
import archspec.cpu
|
||||
|
||||
@ -338,12 +338,36 @@ def __getattr__(self, name):
|
||||
|
||||
fn = AspFunctionBuilder()
|
||||
|
||||
TransformFunction = Callable[[spack.spec.Spec, List[AspFunction]], List[AspFunction]]
|
||||
TransformFunction = Callable[
|
||||
[spack.spec.Spec, spack.spec.Spec, List[AspFunction]], List[AspFunction]
|
||||
]
|
||||
|
||||
|
||||
def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunction]:
|
||||
def transform(
|
||||
required: spack.spec.Spec,
|
||||
imposed: spack.spec.Spec,
|
||||
clauses: List[AspFunction],
|
||||
transformations: Optional[List[TransformFunction]],
|
||||
) -> List[AspFunction]:
|
||||
"""Apply a list of TransformFunctions in order."""
|
||||
if transformations is None:
|
||||
return clauses
|
||||
|
||||
for func in transformations:
|
||||
clauses = func(required, imposed, clauses)
|
||||
return clauses
|
||||
|
||||
|
||||
def cond_key(spec, transforms):
|
||||
"""Key generator for caching triggers and effects"""
|
||||
return (str(spec), None) if transforms is None else (str(spec), tuple(transforms))
|
||||
|
||||
|
||||
def remove_node(
|
||||
required: spack.spec.Spec, imposed: spack.spec.Spec, functions: List[AspFunction]
|
||||
) -> List[AspFunction]:
|
||||
"""Transformation that removes all "node" and "virtual_node" from the input list of facts."""
|
||||
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts))
|
||||
return [func for func in functions if func.args[0] not in ("node", "virtual_node")]
|
||||
|
||||
|
||||
def _create_counter(specs, tests):
|
||||
@ -1501,14 +1525,26 @@ def variant_rules(self, pkg):
|
||||
|
||||
self.gen.newline()
|
||||
|
||||
def _lookup_condition_id(self, condition, transforms, factory, cache_by_name):
|
||||
"""Look up or create a condition in a trigger/effect cache."""
|
||||
key = cond_key(condition, transforms)
|
||||
cache = cache_by_name[condition.name]
|
||||
|
||||
pair = cache.get(key)
|
||||
if pair is None:
|
||||
pair = cache[key] = (next(self._id_counter), factory())
|
||||
|
||||
id, _ = pair
|
||||
return id
|
||||
|
||||
def condition(
|
||||
self,
|
||||
required_spec: spack.spec.Spec,
|
||||
imposed_spec: Optional[spack.spec.Spec] = None,
|
||||
name: Optional[str] = None,
|
||||
msg: Optional[str] = None,
|
||||
transform_required: Optional[TransformFunction] = None,
|
||||
transform_imposed: Optional[TransformFunction] = remove_node,
|
||||
transform_required: Optional[List[TransformFunction]] = None,
|
||||
transform_imposed: Optional[List[TransformFunction]] = [remove_node],
|
||||
):
|
||||
"""Generate facts for a dependency or virtual provider condition.
|
||||
|
||||
@ -1536,35 +1572,27 @@ def condition(
|
||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
|
||||
self.gen.fact(fn.condition_reason(condition_id, msg))
|
||||
|
||||
cache = self._trigger_cache[named_cond.name]
|
||||
|
||||
named_cond_key = (str(named_cond), transform_required)
|
||||
if named_cond_key not in cache:
|
||||
trigger_id = next(self._id_counter)
|
||||
def make_requirements():
|
||||
requirements = self.spec_clauses(named_cond, body=True, required_from=name)
|
||||
return transform(named_cond, imposed_spec, requirements, transform_required)
|
||||
|
||||
if transform_required:
|
||||
requirements = transform_required(named_cond, requirements)
|
||||
|
||||
cache[named_cond_key] = (trigger_id, requirements)
|
||||
trigger_id, requirements = cache[named_cond_key]
|
||||
trigger_id = self._lookup_condition_id(
|
||||
named_cond, transform_required, make_requirements, self._trigger_cache
|
||||
)
|
||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
|
||||
|
||||
if not imposed_spec:
|
||||
return condition_id
|
||||
|
||||
cache = self._effect_cache[named_cond.name]
|
||||
imposed_spec_key = (str(imposed_spec), transform_imposed)
|
||||
if imposed_spec_key not in cache:
|
||||
effect_id = next(self._id_counter)
|
||||
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
|
||||
def make_impositions():
|
||||
impositions = self.spec_clauses(imposed_spec, body=False, required_from=name)
|
||||
return transform(named_cond, imposed_spec, impositions, transform_imposed)
|
||||
|
||||
if transform_imposed:
|
||||
requirements = transform_imposed(imposed_spec, requirements)
|
||||
|
||||
cache[imposed_spec_key] = (effect_id, requirements)
|
||||
effect_id, requirements = cache[imposed_spec_key]
|
||||
effect_id = self._lookup_condition_id(
|
||||
imposed_spec, transform_imposed, make_impositions, self._effect_cache
|
||||
)
|
||||
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
|
||||
|
||||
return condition_id
|
||||
|
||||
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
|
||||
@ -1627,14 +1655,12 @@ def package_dependencies_rules(self, pkg):
|
||||
else:
|
||||
pass
|
||||
|
||||
def track_dependencies(input_spec, requirements):
|
||||
return requirements + [fn.attr("track_dependencies", input_spec.name)]
|
||||
def track_dependencies(required, imposed, requirements):
|
||||
return requirements + [fn.attr("track_dependencies", required.name)]
|
||||
|
||||
def dependency_holds(input_spec, requirements):
|
||||
return remove_node(input_spec, requirements) + [
|
||||
fn.attr(
|
||||
"dependency_holds", pkg.name, input_spec.name, dt.flag_to_string(t)
|
||||
)
|
||||
def dependency_holds(required, imposed, impositions):
|
||||
return impositions + [
|
||||
fn.attr("dependency_holds", pkg.name, imposed.name, dt.flag_to_string(t))
|
||||
for t in dt.ALL_FLAGS
|
||||
if t & depflag
|
||||
]
|
||||
@ -1644,8 +1670,8 @@ def dependency_holds(input_spec, requirements):
|
||||
dep.spec,
|
||||
name=pkg.name,
|
||||
msg=msg,
|
||||
transform_required=track_dependencies,
|
||||
transform_imposed=dependency_holds,
|
||||
transform_required=[track_dependencies],
|
||||
transform_imposed=[remove_node, dependency_holds],
|
||||
)
|
||||
|
||||
self.gen.newline()
|
||||
@ -1751,7 +1777,7 @@ def emit_facts_from_requirement_rules(self, rules: List[RequirementRule]):
|
||||
required_spec=when_spec,
|
||||
imposed_spec=spec,
|
||||
name=pkg_name,
|
||||
transform_imposed=transform,
|
||||
transform_imposed=[transform],
|
||||
msg=f"{spec_str} is a requirement for package {pkg_name}",
|
||||
)
|
||||
except Exception as e:
|
||||
@ -1816,14 +1842,14 @@ def external_packages(self):
|
||||
for local_idx, spec in enumerate(external_specs):
|
||||
msg = "%s available as external when satisfying %s" % (spec.name, spec)
|
||||
|
||||
def external_imposition(input_spec, _):
|
||||
return [fn.attr("external_conditions_hold", input_spec.name, local_idx)]
|
||||
def external_imposition(required, imposed, _):
|
||||
return [fn.attr("external_conditions_hold", imposed.name, local_idx)]
|
||||
|
||||
self.condition(
|
||||
spec,
|
||||
spack.spec.Spec(spec.name),
|
||||
msg=msg,
|
||||
transform_imposed=external_imposition,
|
||||
transform_imposed=[external_imposition],
|
||||
)
|
||||
self.possible_versions[spec.name].add(spec.version)
|
||||
self.gen.newline()
|
||||
|
Loading…
Reference in New Issue
Block a user