refactor: use AspFunction
consistently before and after solve
Currently we create `AspFunction` objects as inputs to solves, but we don't use them when extracting symbols from clingo solves. Use them more consistently in both scenarios, and simplify the code.
This commit is contained in:
parent
20e698c6b0
commit
9ffb642b95
@ -46,14 +46,6 @@ def setup_parser(subparser):
|
||||
)
|
||||
|
||||
|
||||
def shift(asp_function):
|
||||
"""Transforms ``attr("foo", "bar")`` into ``foo("bar")``."""
|
||||
if not asp_function.args:
|
||||
raise ValueError(f"Can't shift ASP function with no arguments: {str(asp_function)}")
|
||||
first, *rest = asp_function.args
|
||||
return asp.AspFunction(first, rest)
|
||||
|
||||
|
||||
def compare_specs(a, b, to_string=False, color=None):
|
||||
"""
|
||||
Generate a comparison, including diffs (for each side) and an intersection.
|
||||
@ -79,7 +71,7 @@ def compare_specs(a, b, to_string=False, color=None):
|
||||
# get facts for specs, making sure to include build dependencies of concrete
|
||||
# specs and to descend into dependency hashes so we include all facts.
|
||||
a_facts = set(
|
||||
shift(func)
|
||||
func.shift()
|
||||
for func in setup.spec_clauses(
|
||||
a,
|
||||
body=True,
|
||||
@ -89,7 +81,7 @@ def compare_specs(a, b, to_string=False, color=None):
|
||||
if func.name == "attr"
|
||||
)
|
||||
b_facts = set(
|
||||
shift(func)
|
||||
func.shift()
|
||||
for func in setup.spec_clauses(
|
||||
b,
|
||||
body=True,
|
||||
|
@ -159,7 +159,7 @@ def build_criteria_names(costs, opt_criteria):
|
||||
"""Construct an ordered mapping from criteria names to costs."""
|
||||
|
||||
# ensure names of all criteria are unique
|
||||
names = {name for _, name in opt_criteria}
|
||||
names = {criterion.args[0] for criterion in opt_criteria}
|
||||
assert len(names) == len(opt_criteria), "names of optimization criteria must be unique"
|
||||
|
||||
# costs contains:
|
||||
@ -186,7 +186,7 @@ def build_criteria_names(costs, opt_criteria):
|
||||
|
||||
# list of build cost, reuse cost, and name of each criterion
|
||||
criteria: List[Tuple[int, int, str]] = []
|
||||
for i, (priority, name) in enumerate(opt_criteria):
|
||||
for i, (priority, name) in enumerate(c.args for c in opt_criteria):
|
||||
priority = int(priority)
|
||||
build_cost = ordered_costs[i]
|
||||
reuse_cost = ordered_costs[i + n_build_criteria] if priority < 100_000 else None
|
||||
@ -279,6 +279,25 @@ def argify(arg):
|
||||
|
||||
return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive)
|
||||
|
||||
@staticmethod
|
||||
def from_symbol(symbol):
|
||||
def deargify(arg):
|
||||
if arg.type is clingo.SymbolType.Number:
|
||||
return arg.number
|
||||
elif arg.type is clingo.SymbolType.String and arg.string in ("True", "False"):
|
||||
return arg.string == "True"
|
||||
else:
|
||||
return arg.string
|
||||
|
||||
return AspFunction(symbol.name, [deargify(arg) for arg in symbol.arguments])
|
||||
|
||||
def shift(self):
|
||||
"""Transforms ``attr("foo", "bar")`` into ``foo("bar")``."""
|
||||
if not self.args:
|
||||
raise ValueError(f"Can't shift ASP function with no arguments: {str(self)}")
|
||||
first, *rest = self.args
|
||||
return AspFunction(first, rest)
|
||||
|
||||
def __str__(self):
|
||||
return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args))
|
||||
|
||||
@ -565,13 +584,13 @@ def stringify(sym):
|
||||
return sym.string or str(sym)
|
||||
|
||||
|
||||
def extract_args(model, predicate_name):
|
||||
"""Extract the arguments to predicates with the provided name from a model.
|
||||
def extract_functions(model, function_name):
|
||||
"""Extract ASP functions with the given name from a model.
|
||||
|
||||
Pull out all the predicates with name ``predicate_name`` from the model, and return
|
||||
their stringified arguments as tuples.
|
||||
Pull out all the functions with name ``function_name`` from the model, and return them as
|
||||
``AspFunction`` objects.
|
||||
"""
|
||||
return [stringify(sym.arguments) for sym in model if sym.name == predicate_name]
|
||||
return [AspFunction.from_symbol(sym) for sym in model if sym.name == function_name]
|
||||
|
||||
|
||||
class PyclingoDriver(object):
|
||||
@ -744,21 +763,21 @@ def on_model(model):
|
||||
min_cost, best_model = min(models)
|
||||
|
||||
# first check for errors
|
||||
error_args = extract_args(best_model, "error")
|
||||
error_args = [fn.args for fn in extract_functions(best_model, "error")]
|
||||
errors = sorted((int(priority), msg, args) for priority, msg, *args in error_args)
|
||||
for _, msg, args in errors:
|
||||
self.handle_error(msg, *args)
|
||||
|
||||
# build specs from spec attributes in the model
|
||||
spec_attrs = [(name, tuple(rest)) for name, *rest in extract_args(best_model, "attr")]
|
||||
spec_attrs = extract_functions(best_model, "attr")
|
||||
answers = builder.build_specs(spec_attrs)
|
||||
|
||||
# add best spec to the results
|
||||
result.answers.append((list(min_cost), 0, answers))
|
||||
|
||||
# get optimization criteria
|
||||
criteria_args = extract_args(best_model, "opt_criterion")
|
||||
result.criteria = build_criteria_names(min_cost, criteria_args)
|
||||
criteria = extract_functions(best_model, "opt_criterion")
|
||||
result.criteria = build_criteria_names(min_cost, criteria)
|
||||
|
||||
# record the number of models the solver considered
|
||||
result.nmodels = len(models)
|
||||
@ -2305,7 +2324,7 @@ def deprecated(self, pkg, version):
|
||||
tty.warn(msg.format(pkg, version))
|
||||
|
||||
@staticmethod
|
||||
def sort_fn(function_tuple):
|
||||
def sort_fn(function):
|
||||
"""Ensure attributes are evaluated in the correct order.
|
||||
|
||||
hash attributes are handled first, since they imply entire concrete specs
|
||||
@ -2315,7 +2334,7 @@ def sort_fn(function_tuple):
|
||||
the concrete specs on which they depend because all nodes are fully constructed before we
|
||||
consider which ones are external.
|
||||
"""
|
||||
name = function_tuple[0]
|
||||
name = function.args[0]
|
||||
if name == "hash":
|
||||
return (-5, 0)
|
||||
elif name == "node":
|
||||
@ -2329,23 +2348,24 @@ def sort_fn(function_tuple):
|
||||
else:
|
||||
return (-1, 0)
|
||||
|
||||
def build_specs(self, function_tuples):
|
||||
def build_specs(self, functions):
|
||||
# Functions don't seem to be in particular order in output. Sort
|
||||
# them here so that directives that build objects (like node and
|
||||
# node_compiler) are called in the right order.
|
||||
self.function_tuples = sorted(set(function_tuples), key=self.sort_fn)
|
||||
self.functions = sorted(set(functions), key=self.sort_fn)
|
||||
|
||||
self._specs = {}
|
||||
for name, args in self.function_tuples:
|
||||
if SpecBuilder.ignored_attributes.match(name):
|
||||
for attr in self.functions:
|
||||
fn = attr.shift() # attr("foo", "bar") -> foo("bar")
|
||||
|
||||
if SpecBuilder.ignored_attributes.match(fn.name):
|
||||
continue
|
||||
|
||||
action = getattr(self, name, None)
|
||||
action = getattr(self, fn.name, None)
|
||||
|
||||
# print out unknown actions so we can display them for debugging
|
||||
if not action:
|
||||
msg = 'UNKNOWN SYMBOL: attr("%s", %s)' % (name, ", ".join(str(a) for a in args))
|
||||
tty.debug(msg)
|
||||
tty.debug(f"UNKNOWN SYMBOL: {attr}")
|
||||
continue
|
||||
|
||||
msg = (
|
||||
@ -2357,8 +2377,8 @@ def build_specs(self, function_tuples):
|
||||
# ignore predicates on virtual packages, as they're used for
|
||||
# solving but don't construct anything. Do not ignore error
|
||||
# predicates on virtual packages.
|
||||
if name != "error":
|
||||
pkg = args[0]
|
||||
if fn.name != "error":
|
||||
pkg = fn.args[0]
|
||||
if spack.repo.path.is_virtual(pkg):
|
||||
continue
|
||||
|
||||
@ -2368,7 +2388,7 @@ def build_specs(self, function_tuples):
|
||||
if spec and spec.concrete:
|
||||
continue
|
||||
|
||||
action(*args)
|
||||
action(*fn.args)
|
||||
|
||||
# namespace assignment is done after the fact, as it is not
|
||||
# currently part of the solve
|
||||
|
Loading…
Reference in New Issue
Block a user