refactor: use AspFunction consistently before and after solve

Currently we create `AspFunction` objects as inputs to solves, but we don't use them
when extracting symbols from clingo solves. Use them more consistently in both
scenarios, and simplify the code.
This commit is contained in:
Todd Gamblin 2023-01-02 17:34:10 -08:00
parent 20e698c6b0
commit 9ffb642b95
No known key found for this signature in database
GPG Key ID: C16729F1AACF66C6
2 changed files with 45 additions and 33 deletions

View File

@ -46,14 +46,6 @@ def setup_parser(subparser):
)
def shift(asp_function):
"""Transforms ``attr("foo", "bar")`` into ``foo("bar")``."""
if not asp_function.args:
raise ValueError(f"Can't shift ASP function with no arguments: {str(asp_function)}")
first, *rest = asp_function.args
return asp.AspFunction(first, rest)
def compare_specs(a, b, to_string=False, color=None):
"""
Generate a comparison, including diffs (for each side) and an intersection.
@ -79,7 +71,7 @@ def compare_specs(a, b, to_string=False, color=None):
# get facts for specs, making sure to include build dependencies of concrete
# specs and to descend into dependency hashes so we include all facts.
a_facts = set(
shift(func)
func.shift()
for func in setup.spec_clauses(
a,
body=True,
@ -89,7 +81,7 @@ def compare_specs(a, b, to_string=False, color=None):
if func.name == "attr"
)
b_facts = set(
shift(func)
func.shift()
for func in setup.spec_clauses(
b,
body=True,

View File

@ -159,7 +159,7 @@ def build_criteria_names(costs, opt_criteria):
"""Construct an ordered mapping from criteria names to costs."""
# ensure names of all criteria are unique
names = {name for _, name in opt_criteria}
names = {criterion.args[0] for criterion in opt_criteria}
assert len(names) == len(opt_criteria), "names of optimization criteria must be unique"
# costs contains:
@ -186,7 +186,7 @@ def build_criteria_names(costs, opt_criteria):
# list of build cost, reuse cost, and name of each criterion
criteria: List[Tuple[int, int, str]] = []
for i, (priority, name) in enumerate(opt_criteria):
for i, (priority, name) in enumerate(c.args for c in opt_criteria):
priority = int(priority)
build_cost = ordered_costs[i]
reuse_cost = ordered_costs[i + n_build_criteria] if priority < 100_000 else None
@ -279,6 +279,25 @@ def argify(arg):
return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive)
@staticmethod
def from_symbol(symbol):
def deargify(arg):
if arg.type is clingo.SymbolType.Number:
return arg.number
elif arg.type is clingo.SymbolType.String and arg.string in ("True", "False"):
return arg.string == "True"
else:
return arg.string
return AspFunction(symbol.name, [deargify(arg) for arg in symbol.arguments])
def shift(self):
"""Transforms ``attr("foo", "bar")`` into ``foo("bar")``."""
if not self.args:
raise ValueError(f"Can't shift ASP function with no arguments: {str(self)}")
first, *rest = self.args
return AspFunction(first, rest)
def __str__(self):
return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args))
@ -565,13 +584,13 @@ def stringify(sym):
return sym.string or str(sym)
def extract_args(model, predicate_name):
"""Extract the arguments to predicates with the provided name from a model.
def extract_functions(model, function_name):
"""Extract ASP functions with the given name from a model.
Pull out all the predicates with name ``predicate_name`` from the model, and return
their stringified arguments as tuples.
Pull out all the functions with name ``function_name`` from the model, and return them as
``AspFunction`` objects.
"""
return [stringify(sym.arguments) for sym in model if sym.name == predicate_name]
return [AspFunction.from_symbol(sym) for sym in model if sym.name == function_name]
class PyclingoDriver(object):
@ -744,21 +763,21 @@ def on_model(model):
min_cost, best_model = min(models)
# first check for errors
error_args = extract_args(best_model, "error")
error_args = [fn.args for fn in extract_functions(best_model, "error")]
errors = sorted((int(priority), msg, args) for priority, msg, *args in error_args)
for _, msg, args in errors:
self.handle_error(msg, *args)
# build specs from spec attributes in the model
spec_attrs = [(name, tuple(rest)) for name, *rest in extract_args(best_model, "attr")]
spec_attrs = extract_functions(best_model, "attr")
answers = builder.build_specs(spec_attrs)
# add best spec to the results
result.answers.append((list(min_cost), 0, answers))
# get optimization criteria
criteria_args = extract_args(best_model, "opt_criterion")
result.criteria = build_criteria_names(min_cost, criteria_args)
criteria = extract_functions(best_model, "opt_criterion")
result.criteria = build_criteria_names(min_cost, criteria)
# record the number of models the solver considered
result.nmodels = len(models)
@ -2305,7 +2324,7 @@ def deprecated(self, pkg, version):
tty.warn(msg.format(pkg, version))
@staticmethod
def sort_fn(function_tuple):
def sort_fn(function):
"""Ensure attributes are evaluated in the correct order.
hash attributes are handled first, since they imply entire concrete specs
@ -2315,7 +2334,7 @@ def sort_fn(function_tuple):
the concrete specs on which they depend because all nodes are fully constructed before we
consider which ones are external.
"""
name = function_tuple[0]
name = function.args[0]
if name == "hash":
return (-5, 0)
elif name == "node":
@ -2329,23 +2348,24 @@ def sort_fn(function_tuple):
else:
return (-1, 0)
def build_specs(self, function_tuples):
def build_specs(self, functions):
# Functions don't seem to be in particular order in output. Sort
# them here so that directives that build objects (like node and
# node_compiler) are called in the right order.
self.function_tuples = sorted(set(function_tuples), key=self.sort_fn)
self.functions = sorted(set(functions), key=self.sort_fn)
self._specs = {}
for name, args in self.function_tuples:
if SpecBuilder.ignored_attributes.match(name):
for attr in self.functions:
fn = attr.shift() # attr("foo", "bar") -> foo("bar")
if SpecBuilder.ignored_attributes.match(fn.name):
continue
action = getattr(self, name, None)
action = getattr(self, fn.name, None)
# print out unknown actions so we can display them for debugging
if not action:
msg = 'UNKNOWN SYMBOL: attr("%s", %s)' % (name, ", ".join(str(a) for a in args))
tty.debug(msg)
tty.debug(f"UNKNOWN SYMBOL: {attr}")
continue
msg = (
@ -2357,8 +2377,8 @@ def build_specs(self, function_tuples):
# ignore predicates on virtual packages, as they're used for
# solving but don't construct anything. Do not ignore error
# predicates on virtual packages.
if name != "error":
pkg = args[0]
if fn.name != "error":
pkg = fn.args[0]
if spack.repo.path.is_virtual(pkg):
continue
@ -2368,7 +2388,7 @@ def build_specs(self, function_tuples):
if spec and spec.concrete:
continue
action(*args)
action(*fn.args)
# namespace assignment is done after the fact, as it is not
# currently part of the solve