From 9ffb642b95bbb2f633fe44e2c5b1bd0bf482ab70 Mon Sep 17 00:00:00 2001 From: Todd Gamblin Date: Mon, 2 Jan 2023 17:34:10 -0800 Subject: [PATCH] refactor: use `AspFunction` consistently before and after solve Currently we create `AspFunction` objects as inputs to solves, but we don't use them when extracting symbols from clingo solves. Use them more consistently in both scenarios, and simplify the code. --- lib/spack/spack/cmd/diff.py | 12 ++----- lib/spack/spack/solver/asp.py | 66 +++++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/lib/spack/spack/cmd/diff.py b/lib/spack/spack/cmd/diff.py index 5206a246c0b..402b503c561 100644 --- a/lib/spack/spack/cmd/diff.py +++ b/lib/spack/spack/cmd/diff.py @@ -46,14 +46,6 @@ def setup_parser(subparser): ) -def shift(asp_function): - """Transforms ``attr("foo", "bar")`` into ``foo("bar")``.""" - if not asp_function.args: - raise ValueError(f"Can't shift ASP function with no arguments: {str(asp_function)}") - first, *rest = asp_function.args - return asp.AspFunction(first, rest) - - def compare_specs(a, b, to_string=False, color=None): """ Generate a comparison, including diffs (for each side) and an intersection. @@ -79,7 +71,7 @@ def compare_specs(a, b, to_string=False, color=None): # get facts for specs, making sure to include build dependencies of concrete # specs and to descend into dependency hashes so we include all facts. a_facts = set( - shift(func) + func.shift() for func in setup.spec_clauses( a, body=True, @@ -89,7 +81,7 @@ def compare_specs(a, b, to_string=False, color=None): if func.name == "attr" ) b_facts = set( - shift(func) + func.shift() for func in setup.spec_clauses( b, body=True, diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py index 74e3a538775..1dc3a09c6f5 100644 --- a/lib/spack/spack/solver/asp.py +++ b/lib/spack/spack/solver/asp.py @@ -159,7 +159,7 @@ def build_criteria_names(costs, opt_criteria): """Construct an ordered mapping from criteria names to costs.""" # ensure names of all criteria are unique - names = {name for _, name in opt_criteria} + names = {criterion.args[0] for criterion in opt_criteria} assert len(names) == len(opt_criteria), "names of optimization criteria must be unique" # costs contains: @@ -186,7 +186,7 @@ def build_criteria_names(costs, opt_criteria): # list of build cost, reuse cost, and name of each criterion criteria: List[Tuple[int, int, str]] = [] - for i, (priority, name) in enumerate(opt_criteria): + for i, (priority, name) in enumerate(c.args for c in opt_criteria): priority = int(priority) build_cost = ordered_costs[i] reuse_cost = ordered_costs[i + n_build_criteria] if priority < 100_000 else None @@ -279,6 +279,25 @@ def argify(arg): return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive) + @staticmethod + def from_symbol(symbol): + def deargify(arg): + if arg.type is clingo.SymbolType.Number: + return arg.number + elif arg.type is clingo.SymbolType.String and arg.string in ("True", "False"): + return arg.string == "True" + else: + return arg.string + + return AspFunction(symbol.name, [deargify(arg) for arg in symbol.arguments]) + + def shift(self): + """Transforms ``attr("foo", "bar")`` into ``foo("bar")``.""" + if not self.args: + raise ValueError(f"Can't shift ASP function with no arguments: {str(self)}") + first, *rest = self.args + return AspFunction(first, rest) + def __str__(self): return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args)) @@ -565,13 +584,13 @@ def stringify(sym): return sym.string or str(sym) -def extract_args(model, predicate_name): - """Extract the arguments to predicates with the provided name from a model. +def extract_functions(model, function_name): + """Extract ASP functions with the given name from a model. - Pull out all the predicates with name ``predicate_name`` from the model, and return - their stringified arguments as tuples. + Pull out all the functions with name ``function_name`` from the model, and return them as + ``AspFunction`` objects. """ - return [stringify(sym.arguments) for sym in model if sym.name == predicate_name] + return [AspFunction.from_symbol(sym) for sym in model if sym.name == function_name] class PyclingoDriver(object): @@ -744,21 +763,21 @@ def on_model(model): min_cost, best_model = min(models) # first check for errors - error_args = extract_args(best_model, "error") + error_args = [fn.args for fn in extract_functions(best_model, "error")] errors = sorted((int(priority), msg, args) for priority, msg, *args in error_args) for _, msg, args in errors: self.handle_error(msg, *args) # build specs from spec attributes in the model - spec_attrs = [(name, tuple(rest)) for name, *rest in extract_args(best_model, "attr")] + spec_attrs = extract_functions(best_model, "attr") answers = builder.build_specs(spec_attrs) # add best spec to the results result.answers.append((list(min_cost), 0, answers)) # get optimization criteria - criteria_args = extract_args(best_model, "opt_criterion") - result.criteria = build_criteria_names(min_cost, criteria_args) + criteria = extract_functions(best_model, "opt_criterion") + result.criteria = build_criteria_names(min_cost, criteria) # record the number of models the solver considered result.nmodels = len(models) @@ -2305,7 +2324,7 @@ def deprecated(self, pkg, version): tty.warn(msg.format(pkg, version)) @staticmethod - def sort_fn(function_tuple): + def sort_fn(function): """Ensure attributes are evaluated in the correct order. hash attributes are handled first, since they imply entire concrete specs @@ -2315,7 +2334,7 @@ def sort_fn(function_tuple): the concrete specs on which they depend because all nodes are fully constructed before we consider which ones are external. """ - name = function_tuple[0] + name = function.args[0] if name == "hash": return (-5, 0) elif name == "node": @@ -2329,23 +2348,24 @@ def sort_fn(function_tuple): else: return (-1, 0) - def build_specs(self, function_tuples): + def build_specs(self, functions): # Functions don't seem to be in particular order in output. Sort # them here so that directives that build objects (like node and # node_compiler) are called in the right order. - self.function_tuples = sorted(set(function_tuples), key=self.sort_fn) + self.functions = sorted(set(functions), key=self.sort_fn) self._specs = {} - for name, args in self.function_tuples: - if SpecBuilder.ignored_attributes.match(name): + for attr in self.functions: + fn = attr.shift() # attr("foo", "bar") -> foo("bar") + + if SpecBuilder.ignored_attributes.match(fn.name): continue - action = getattr(self, name, None) + action = getattr(self, fn.name, None) # print out unknown actions so we can display them for debugging if not action: - msg = 'UNKNOWN SYMBOL: attr("%s", %s)' % (name, ", ".join(str(a) for a in args)) - tty.debug(msg) + tty.debug(f"UNKNOWN SYMBOL: {attr}") continue msg = ( @@ -2357,8 +2377,8 @@ def build_specs(self, function_tuples): # ignore predicates on virtual packages, as they're used for # solving but don't construct anything. Do not ignore error # predicates on virtual packages. - if name != "error": - pkg = args[0] + if fn.name != "error": + pkg = fn.args[0] if spack.repo.path.is_virtual(pkg): continue @@ -2368,7 +2388,7 @@ def build_specs(self, function_tuples): if spec and spec.concrete: continue - action(*args) + action(*fn.args) # namespace assignment is done after the fact, as it is not # currently part of the solve