Compare commits

...

2 Commits

Author SHA1 Message Date
Todd Gamblin
e3fc937b32
WIP 2023-12-14 14:27:59 -08:00
Todd Gamblin
977f6ce65c
solver: refactor transforms in condition generation
- [x] allow caller of `condition()` to pass lists of transforms
- [x] all transform functions now take trigger *and* effect as parameters
- [x] add some utility functions to simplify `condition()`
2023-12-14 13:50:41 -08:00
4 changed files with 187 additions and 42 deletions

View File

@ -13,7 +13,7 @@
import re import re
import types import types
import warnings import warnings
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
import archspec.cpu import archspec.cpu
@ -311,6 +311,29 @@ def __call__(self, *args):
""" """
return AspFunction(self.name, self.args + args) return AspFunction(self.name, self.args + args)
def match(self, pattern: "AspFunction"):
"""Compare name and args of this ASP function to a match pattern.
Arguments of ``pattern`` function can be strings, arbitrary objects or ``any``:
* ``any`` matches any argument;
* ``str`` arguments are treated as regular expressions and match against the
string representation of the args of this function.
* any other object is compared with `==`.
"""
if self.name != pattern.name or len(pattern.args) > len(self.args):
return False
for parg, arg in zip(pattern.args, self.args):
if parg is any:
continue
elif isinstance(parg, str) and not re.match(parg, str(arg)):
return False
elif parg != arg:
return False
return True
def symbol(self, positive=True): def symbol(self, positive=True):
def argify(arg): def argify(arg):
if isinstance(arg, bool): if isinstance(arg, bool):
@ -338,12 +361,36 @@ def __getattr__(self, name):
fn = AspFunctionBuilder() fn = AspFunctionBuilder()
TransformFunction = Callable[[spack.spec.Spec, List[AspFunction]], List[AspFunction]] TransformFunction = Callable[
[spack.spec.Spec, spack.spec.Spec, List[AspFunction]], List[AspFunction]
]
def remove_node(spec: spack.spec.Spec, facts: List[AspFunction]) -> List[AspFunction]: def transform(
required: spack.spec.Spec,
imposed: spack.spec.Spec,
clauses: List[AspFunction],
transformations: Optional[List[TransformFunction]],
) -> List[AspFunction]:
"""Apply a list of TransformFunctions in order."""
if transformations is None:
return clauses
for func in transformations:
clauses = func(required, imposed, clauses)
return clauses
def cond_key(spec, transforms):
"""Key generator for caching triggers and effects"""
return (str(spec), None) if transforms is None else (str(spec), tuple(transforms))
def remove_node(
required: spack.spec.Spec, imposed: spack.spec.Spec, functions: List[AspFunction]
) -> List[AspFunction]:
"""Transformation that removes all "node" and "virtual_node" from the input list of facts.""" """Transformation that removes all "node" and "virtual_node" from the input list of facts."""
return list(filter(lambda x: x.args[0] not in ("node", "virtual_node"), facts)) return [func for func in functions if func.args[0] not in ("node", "virtual_node")]
def _create_counter(specs, tests): def _create_counter(specs, tests):
@ -1107,6 +1154,7 @@ def __init__(self, tests=False):
self.possible_oses = set() self.possible_oses = set()
self.variant_values_from_specs = set() self.variant_values_from_specs = set()
self.version_constraints = set() self.version_constraints = set()
self.synced_version_constraints = set()
self.target_constraints = set() self.target_constraints = set()
self.default_targets = [] self.default_targets = []
self.compiler_version_constraints = set() self.compiler_version_constraints = set()
@ -1501,14 +1549,26 @@ def variant_rules(self, pkg):
self.gen.newline() self.gen.newline()
def _lookup_condition_id(self, condition, transforms, factory, cache_by_name):
"""Look up or create a condition in a trigger/effect cache."""
key = cond_key(condition, transforms)
cache = cache_by_name[condition.name]
pair = cache.get(key)
if pair is None:
pair = cache[key] = (next(self._id_counter), factory())
id, _ = pair
return id
def condition( def condition(
self, self,
required_spec: spack.spec.Spec, required_spec: spack.spec.Spec,
imposed_spec: Optional[spack.spec.Spec] = None, imposed_spec: Optional[spack.spec.Spec] = None,
name: Optional[str] = None, name: Optional[str] = None,
msg: Optional[str] = None, msg: Optional[str] = None,
transform_required: Optional[TransformFunction] = None, transform_required: Optional[List[TransformFunction]] = None,
transform_imposed: Optional[TransformFunction] = remove_node, transform_imposed: Optional[List[TransformFunction]] = [remove_node],
): ):
"""Generate facts for a dependency or virtual provider condition. """Generate facts for a dependency or virtual provider condition.
@ -1536,35 +1596,27 @@ def condition(
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition(condition_id)))
self.gen.fact(fn.condition_reason(condition_id, msg)) self.gen.fact(fn.condition_reason(condition_id, msg))
cache = self._trigger_cache[named_cond.name] def make_requirements():
named_cond_key = (str(named_cond), transform_required)
if named_cond_key not in cache:
trigger_id = next(self._id_counter)
requirements = self.spec_clauses(named_cond, body=True, required_from=name) requirements = self.spec_clauses(named_cond, body=True, required_from=name)
return transform(named_cond, imposed_spec, requirements, transform_required)
if transform_required: trigger_id = self._lookup_condition_id(
requirements = transform_required(named_cond, requirements) named_cond, transform_required, make_requirements, self._trigger_cache
)
cache[named_cond_key] = (trigger_id, requirements)
trigger_id, requirements = cache[named_cond_key]
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_trigger(condition_id, trigger_id)))
if not imposed_spec: if not imposed_spec:
return condition_id return condition_id
cache = self._effect_cache[named_cond.name] def make_impositions():
imposed_spec_key = (str(imposed_spec), transform_imposed) impositions = self.spec_clauses(imposed_spec, body=False, required_from=name)
if imposed_spec_key not in cache: return transform(named_cond, imposed_spec, impositions, transform_imposed)
effect_id = next(self._id_counter)
requirements = self.spec_clauses(imposed_spec, body=False, required_from=name)
if transform_imposed: effect_id = self._lookup_condition_id(
requirements = transform_imposed(imposed_spec, requirements) imposed_spec, transform_imposed, make_impositions, self._effect_cache
)
cache[imposed_spec_key] = (effect_id, requirements)
effect_id, requirements = cache[imposed_spec_key]
self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id))) self.gen.fact(fn.pkg_fact(named_cond.name, fn.condition_effect(condition_id, effect_id)))
return condition_id return condition_id
def impose(self, condition_id, imposed_spec, node=True, name=None, body=False): def impose(self, condition_id, imposed_spec, node=True, name=None, body=False):
@ -1603,6 +1655,27 @@ def package_provider_rules(self, pkg):
) )
self.gen.newline() self.gen.newline()
def transform_my_version(
self, require: spack.spec.Spec, impose: spack.spec.Spec, funcs: List[AspFunction]
) -> List[AspFunction]:
"""Replace symbolic "my" version with reference to dependent's version."""
result = []
for f in funcs:
if not f.match(fn.attr("node_version_satisfies", any, r"^my\.version$")):
result.append(f)
continue
# get Version from version(Package, Version) and generate
# node_version_satisfies(dep, Version)
dep = f.args[1]
sync = fn.attr("sync", dep, "node_version_satisfies", fn.attr("version", require.name))
result.append(sync)
# remember to generate version_satisfies/3 for my.version constraints
self.synced_version_constraints.add((require.name, dep))
return result
def package_dependencies_rules(self, pkg): def package_dependencies_rules(self, pkg):
"""Translate 'depends_on' directives into ASP logic.""" """Translate 'depends_on' directives into ASP logic."""
for _, conditions in sorted(pkg.dependencies.items()): for _, conditions in sorted(pkg.dependencies.items()):
@ -1627,14 +1700,12 @@ def package_dependencies_rules(self, pkg):
else: else:
pass pass
def track_dependencies(input_spec, requirements): def track_dependencies(required, imposed, requirements):
return requirements + [fn.attr("track_dependencies", input_spec.name)] return requirements + [fn.attr("track_dependencies", required.name)]
def dependency_holds(input_spec, requirements): def dependency_holds(required, imposed, impositions):
return remove_node(input_spec, requirements) + [ return impositions + [
fn.attr( fn.attr("dependency_holds", pkg.name, imposed.name, dt.flag_to_string(t))
"dependency_holds", pkg.name, input_spec.name, dt.flag_to_string(t)
)
for t in dt.ALL_FLAGS for t in dt.ALL_FLAGS
if t & depflag if t & depflag
] ]
@ -1644,8 +1715,8 @@ def dependency_holds(input_spec, requirements):
dep.spec, dep.spec,
name=pkg.name, name=pkg.name,
msg=msg, msg=msg,
transform_required=track_dependencies, transform_required=[track_dependencies],
transform_imposed=dependency_holds, transform_imposed=[remove_node, dependency_holds, self.transform_my_version],
) )
self.gen.newline() self.gen.newline()
@ -1743,15 +1814,13 @@ def emit_facts_from_requirement_rules(self, rules: List[RequirementRule]):
try: try:
# With virtual we want to emit "node" and "virtual_node" in imposed specs # With virtual we want to emit "node" and "virtual_node" in imposed specs
transform: Optional[TransformFunction] = remove_node transform = None if virtual else [remove_node]
if virtual:
transform = None
member_id = self.condition( member_id = self.condition(
required_spec=when_spec, required_spec=when_spec,
imposed_spec=spec, imposed_spec=spec,
name=pkg_name, name=pkg_name,
transform_imposed=transform, transform_imposed=[transform],
msg=f"{spec_str} is a requirement for package {pkg_name}", msg=f"{spec_str} is a requirement for package {pkg_name}",
) )
except Exception as e: except Exception as e:
@ -1816,14 +1885,14 @@ def external_packages(self):
for local_idx, spec in enumerate(external_specs): for local_idx, spec in enumerate(external_specs):
msg = "%s available as external when satisfying %s" % (spec.name, spec) msg = "%s available as external when satisfying %s" % (spec.name, spec)
def external_imposition(input_spec, _): def external_imposition(required, imposed, _):
return [fn.attr("external_conditions_hold", input_spec.name, local_idx)] return [fn.attr("external_conditions_hold", imposed.name, local_idx)]
self.condition( self.condition(
spec, spec,
spack.spec.Spec(spec.name), spack.spec.Spec(spec.name),
msg=msg, msg=msg,
transform_imposed=external_imposition, transform_imposed=[external_imposition],
) )
self.possible_versions[spec.name].add(spec.version) self.possible_versions[spec.name].add(spec.version)
self.gen.newline() self.gen.newline()
@ -2414,6 +2483,15 @@ def generate_possible_compilers(self, specs):
def define_version_constraints(self): def define_version_constraints(self):
"""Define what version_satisfies(...) means in ASP logic.""" """Define what version_satisfies(...) means in ASP logic."""
# quadratic for now b/c we're anticipating pkg_ver being an
# expression/range/etc. right now this only does exact matches until we can
# propagate an expression from depends_on to here.
for pkg, dep in self.synced_version_constraints:
for pkg_ver in self.possible_versions[pkg]:
for dep_ver in self.possible_versions[dep]:
if dep_ver.satisfies(pkg_ver):
self.gen.fact(fn.pkg_fact(dep, fn.version_satisfies(dep_ver, pkg_ver)))
for pkg_name, versions in sorted(self.version_constraints): for pkg_name, versions in sorted(self.version_constraints):
# generate facts for each package constraint and the version # generate facts for each package constraint and the version
# that satisfies it # that satisfies it

View File

@ -366,6 +366,11 @@ attr(Name, node(X, A1), A2) :- impose(ID, PackageNode), imposed_constrai
attr(Name, node(X, A1), A2, A3) :- impose(ID, PackageNode), imposed_constraint(ID, Name, A1, A2, A3), imposed_nodes(ID, PackageNode, node(X, A1)), not multiple_nodes_attribute(Name). attr(Name, node(X, A1), A2, A3) :- impose(ID, PackageNode), imposed_constraint(ID, Name, A1, A2, A3), imposed_nodes(ID, PackageNode, node(X, A1)), not multiple_nodes_attribute(Name).
attr(Name, node(X, A1), A2, A3, A4) :- impose(ID, PackageNode), imposed_constraint(ID, Name, A1, A2, A3, A4), imposed_nodes(ID, PackageNode, node(X, A1)). attr(Name, node(X, A1), A2, A3, A4) :- impose(ID, PackageNode), imposed_constraint(ID, Name, A1, A2, A3, A4), imposed_nodes(ID, PackageNode, node(X, A1)).
attr(DepAttrName, DepNode, Value)
:- depends_on(node(X, Package), DepNode),
attr(PackageAttrName, node(X, Package), Value),
attr("sync", DepNode, DepAttrName, attr(PackageAttrName, Package)).
% For node flag sources we need to look at the condition_set of the source, since it is the dependent % For node flag sources we need to look at the condition_set of the source, since it is the dependent
% of the package on which I want to impose the constraint % of the package on which I want to impose the constraint
attr("node_flag_source", node(X, A1), A2, node(Y, A3)) attr("node_flag_source", node(X, A1), A2, node(Y, A3))

View File

@ -0,0 +1,30 @@
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack.package import *
class VersionLockDep(Package):
"""version-lock-dep is depended on by version-lock with the same version"""
homepage = "http://example.com/version-lock-dep/"
url = "http://example.com/version-lock-dep.tar.gz"
version("3.2.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.2.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.1.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.1.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.0.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.2.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.2.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.1.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.1.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.0.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.2.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.2.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.1.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.1.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.0.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")

View File

@ -0,0 +1,32 @@
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack.package import *
class VersionLock(Package):
"""version-lock depends on version-lock-dep with the same version"""
homepage = "http://example.com/version-lock/"
url = "http://example.com/version-lock.tar.gz"
version("3.2.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.2.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.1.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.1.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("3.0.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.2.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.2.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.1.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.1.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("2.0.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.2.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.2.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.1.1", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.1.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
version("1.0.0", sha256="18d459400558f4ea99527bc9786c033965a3db45bf4c6a32eefdc07aa9e306a6")
depends_on("version-lock-dep@my.version")