ASP-based solver: decouple setup phase from clingo.backend (#41952)
				
					
				
			Currently, the `SpackSolverSetup` and the `PyclingoDriver` are more coupled than necessary: 1. The driver object needs a setup object to be injected during a solve, 2. And the setup object will get a reference back to the driver This design is necessary because we use the low-level `clingo.backend` interface to setup our problem. This interface though is meant to bypass the grounder and add symbols directly in the grounded table, which is a feature we don't currently use. The PR simplifies the encoding by having the setup object returning the problem-specific facts / rules as a list of strings, and the driver ingesting them using the [clingo.Control.add](https://potassco.org/clingo/python-api/5.6/clingo/control.html#clingo.control.Control.add) method. This removes any use of the low level interface. Using this encoding makes it easy to hash the output of the setup phase, since it is returned as a string.
This commit is contained in:
		
				
					committed by
					
						
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							97fb9565ee
						
					
				
				
					commit
					5c49bb45c7
				
			@@ -96,6 +96,7 @@
 | 
			
		||||
# these are from clingo.ast and bootstrapped later
 | 
			
		||||
ASTType = None
 | 
			
		||||
parse_files = None
 | 
			
		||||
parse_term = None
 | 
			
		||||
 | 
			
		||||
#: Enable the addition of a runtime node
 | 
			
		||||
WITH_RUNTIME = sys.platform != "win32"
 | 
			
		||||
@@ -310,11 +311,11 @@ def _id(thing):
 | 
			
		||||
    if isinstance(thing, AspObject):
 | 
			
		||||
        return thing
 | 
			
		||||
    elif isinstance(thing, bool):
 | 
			
		||||
        return '"%s"' % str(thing)
 | 
			
		||||
        return f'"{str(thing)}"'
 | 
			
		||||
    elif isinstance(thing, int):
 | 
			
		||||
        return str(thing)
 | 
			
		||||
    else:
 | 
			
		||||
        return '"%s"' % str(thing)
 | 
			
		||||
        return f'"{str(thing)}"'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@llnl.util.lang.key_ordering
 | 
			
		||||
@@ -351,21 +352,20 @@ def __call__(self, *args):
 | 
			
		||||
        """
 | 
			
		||||
        return AspFunction(self.name, self.args + args)
 | 
			
		||||
 | 
			
		||||
    def symbol(self, positive=True):
 | 
			
		||||
        def argify(arg):
 | 
			
		||||
            if isinstance(arg, bool):
 | 
			
		||||
                return clingo.String(str(arg))
 | 
			
		||||
            elif isinstance(arg, int):
 | 
			
		||||
                return clingo.Number(arg)
 | 
			
		||||
            elif isinstance(arg, AspFunction):
 | 
			
		||||
                return clingo.Function(arg.name, [argify(x) for x in arg.args], positive=positive)
 | 
			
		||||
            else:
 | 
			
		||||
                return clingo.String(str(arg))
 | 
			
		||||
    def argify(self, arg):
 | 
			
		||||
        if isinstance(arg, bool):
 | 
			
		||||
            return clingo.String(str(arg))
 | 
			
		||||
        elif isinstance(arg, int):
 | 
			
		||||
            return clingo.Number(arg)
 | 
			
		||||
        elif isinstance(arg, AspFunction):
 | 
			
		||||
            return clingo.Function(arg.name, [self.argify(x) for x in arg.args], positive=True)
 | 
			
		||||
        return clingo.String(str(arg))
 | 
			
		||||
 | 
			
		||||
        return clingo.Function(self.name, [argify(arg) for arg in self.args], positive=positive)
 | 
			
		||||
    def symbol(self):
 | 
			
		||||
        return clingo.Function(self.name, [self.argify(arg) for arg in self.args], positive=True)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return "%s(%s)" % (self.name, ", ".join(str(_id(arg)) for arg in self.args))
 | 
			
		||||
        return f"{self.name}({', '.join(str(_id(arg)) for arg in self.args)})"
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return str(self)
 | 
			
		||||
@@ -664,7 +664,7 @@ def _spec_with_default_name(spec_str, name):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bootstrap_clingo():
 | 
			
		||||
    global clingo, ASTType, parse_files
 | 
			
		||||
    global clingo, ASTType, parse_files, parse_term
 | 
			
		||||
 | 
			
		||||
    if not clingo:
 | 
			
		||||
        import spack.bootstrap
 | 
			
		||||
@@ -677,9 +677,10 @@ def bootstrap_clingo():
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        from clingo.ast import parse_files
 | 
			
		||||
        from clingo.symbol import parse_term
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        # older versions of clingo have this one namespace up
 | 
			
		||||
        from clingo import parse_files
 | 
			
		||||
        from clingo import parse_files, parse_term
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NodeArgument(NamedTuple):
 | 
			
		||||
@@ -882,53 +883,9 @@ def __init__(self, cores=True):
 | 
			
		||||
                error reporting.
 | 
			
		||||
        """
 | 
			
		||||
        bootstrap_clingo()
 | 
			
		||||
 | 
			
		||||
        self.out = llnl.util.lang.Devnull()
 | 
			
		||||
        self.cores = cores
 | 
			
		||||
 | 
			
		||||
        # These attributes are part of the object, but will be reset
 | 
			
		||||
        # at each call to solve
 | 
			
		||||
        # This attribute will be reset at each call to solve
 | 
			
		||||
        self.control = None
 | 
			
		||||
        self.backend = None
 | 
			
		||||
        self.assumptions = None
 | 
			
		||||
 | 
			
		||||
    def title(self, name, char):
 | 
			
		||||
        self.out.write("\n")
 | 
			
		||||
        self.out.write("%" + (char * 76))
 | 
			
		||||
        self.out.write("\n")
 | 
			
		||||
        self.out.write("%% %s\n" % name)
 | 
			
		||||
        self.out.write("%" + (char * 76))
 | 
			
		||||
        self.out.write("\n")
 | 
			
		||||
 | 
			
		||||
    def h1(self, name):
 | 
			
		||||
        self.title(name, "=")
 | 
			
		||||
 | 
			
		||||
    def h2(self, name):
 | 
			
		||||
        self.title(name, "-")
 | 
			
		||||
 | 
			
		||||
    def newline(self):
 | 
			
		||||
        self.out.write("\n")
 | 
			
		||||
 | 
			
		||||
    def fact(self, head):
 | 
			
		||||
        """ASP fact (a rule without a body).
 | 
			
		||||
 | 
			
		||||
        Arguments:
 | 
			
		||||
            head (AspFunction): ASP function to generate as fact
 | 
			
		||||
        """
 | 
			
		||||
        symbol = head.symbol() if hasattr(head, "symbol") else head
 | 
			
		||||
 | 
			
		||||
        # This is commented out to avoid evaluating str(symbol) when we have no stream
 | 
			
		||||
        if not isinstance(self.out, llnl.util.lang.Devnull):
 | 
			
		||||
            self.out.write(f"{str(symbol)}.\n")
 | 
			
		||||
 | 
			
		||||
        atom = self.backend.add_atom(symbol)
 | 
			
		||||
 | 
			
		||||
        # Only functions relevant for constructing bug reports for bad error messages
 | 
			
		||||
        # are assumptions, and only when using cores.
 | 
			
		||||
        choice = self.cores and symbol.name == "internal_error"
 | 
			
		||||
        self.backend.add_rule([atom], [], choice=choice)
 | 
			
		||||
        if choice:
 | 
			
		||||
            self.assumptions.append(atom)
 | 
			
		||||
 | 
			
		||||
    def solve(self, setup, specs, reuse=None, output=None, control=None, allow_deprecated=False):
 | 
			
		||||
        """Set up the input and solve for dependencies of ``specs``.
 | 
			
		||||
@@ -948,49 +905,24 @@ def solve(self, setup, specs, reuse=None, output=None, control=None, allow_depre
 | 
			
		||||
            solve, and the internal statistics from clingo.
 | 
			
		||||
        """
 | 
			
		||||
        output = output or DEFAULT_OUTPUT_CONFIGURATION
 | 
			
		||||
        # allow solve method to override the output stream
 | 
			
		||||
        if output.out is not None:
 | 
			
		||||
            self.out = output.out
 | 
			
		||||
 | 
			
		||||
        timer = spack.util.timer.Timer()
 | 
			
		||||
 | 
			
		||||
        # Initialize the control object for the solver
 | 
			
		||||
        self.control = control or default_clingo_control()
 | 
			
		||||
        # set up the problem -- this generates facts and rules
 | 
			
		||||
        self.assumptions = []
 | 
			
		||||
 | 
			
		||||
        timer.start("setup")
 | 
			
		||||
        with self.control.backend() as backend:
 | 
			
		||||
            self.backend = backend
 | 
			
		||||
            setup.setup(self, specs, reuse=reuse, allow_deprecated=allow_deprecated)
 | 
			
		||||
        asp_problem = setup.setup(specs, reuse=reuse, allow_deprecated=allow_deprecated)
 | 
			
		||||
        if output.out is not None:
 | 
			
		||||
            output.out.write(asp_problem)
 | 
			
		||||
        if output.setup_only:
 | 
			
		||||
            return Result(specs), None, None
 | 
			
		||||
        timer.stop("setup")
 | 
			
		||||
 | 
			
		||||
        timer.start("load")
 | 
			
		||||
        # read in the main ASP program and display logic -- these are
 | 
			
		||||
        # handwritten, not generated, so we load them as resources
 | 
			
		||||
        parent_dir = os.path.dirname(__file__)
 | 
			
		||||
 | 
			
		||||
        # extract error messages from concretize.lp by inspecting its AST
 | 
			
		||||
        with self.backend:
 | 
			
		||||
 | 
			
		||||
            def visit(node):
 | 
			
		||||
                if ast_type(node) == ASTType.Rule:
 | 
			
		||||
                    for term in node.body:
 | 
			
		||||
                        if ast_type(term) == ASTType.Literal:
 | 
			
		||||
                            if ast_type(term.atom) == ASTType.SymbolicAtom:
 | 
			
		||||
                                name = ast_sym(term.atom).name
 | 
			
		||||
                                if name == "internal_error":
 | 
			
		||||
                                    arg = ast_sym(ast_sym(term.atom).arguments[0])
 | 
			
		||||
                                    self.fact(AspFunction(name)(arg.string))
 | 
			
		||||
 | 
			
		||||
            self.h1("Error messages")
 | 
			
		||||
            path = os.path.join(parent_dir, "concretize.lp")
 | 
			
		||||
            parse_files([path], visit)
 | 
			
		||||
 | 
			
		||||
        # If we're only doing setup, just return an empty solve result
 | 
			
		||||
        if output.setup_only:
 | 
			
		||||
            return Result(specs), None, None
 | 
			
		||||
 | 
			
		||||
        # Add the problem instance
 | 
			
		||||
        self.control.add("base", [], asp_problem)
 | 
			
		||||
        # Load the file itself
 | 
			
		||||
        parent_dir = os.path.dirname(__file__)
 | 
			
		||||
        self.control.load(os.path.join(parent_dir, "concretize.lp"))
 | 
			
		||||
        self.control.load(os.path.join(parent_dir, "heuristic.lp"))
 | 
			
		||||
        if spack.config.CONFIG.get("concretizer:duplicates:strategy", "none") != "none":
 | 
			
		||||
@@ -1016,7 +948,7 @@ def on_model(model):
 | 
			
		||||
            models.append((model.cost, model.symbols(shown=True, terms=True)))
 | 
			
		||||
 | 
			
		||||
        solve_kwargs = {
 | 
			
		||||
            "assumptions": self.assumptions,
 | 
			
		||||
            "assumptions": setup.assumptions,
 | 
			
		||||
            "on_model": on_model,
 | 
			
		||||
            "on_core": cores.append,
 | 
			
		||||
        }
 | 
			
		||||
@@ -1142,6 +1074,7 @@ class SpackSolverSetup:
 | 
			
		||||
    def __init__(self, tests=False):
 | 
			
		||||
        self.gen = None  # set by setup()
 | 
			
		||||
 | 
			
		||||
        self.assumptions = []
 | 
			
		||||
        self.declared_versions = collections.defaultdict(list)
 | 
			
		||||
        self.possible_versions = collections.defaultdict(set)
 | 
			
		||||
        self.deprecated_versions = collections.defaultdict(set)
 | 
			
		||||
@@ -1878,36 +1811,7 @@ def _spec_clauses(
 | 
			
		||||
        """
 | 
			
		||||
        clauses = []
 | 
			
		||||
 | 
			
		||||
        # TODO: do this with consistent suffixes.
 | 
			
		||||
        class Head:
 | 
			
		||||
            node = fn.attr("node")
 | 
			
		||||
            virtual_node = fn.attr("virtual_node")
 | 
			
		||||
            node_platform = fn.attr("node_platform_set")
 | 
			
		||||
            node_os = fn.attr("node_os_set")
 | 
			
		||||
            node_target = fn.attr("node_target_set")
 | 
			
		||||
            variant_value = fn.attr("variant_set")
 | 
			
		||||
            node_compiler = fn.attr("node_compiler_set")
 | 
			
		||||
            node_compiler_version = fn.attr("node_compiler_version_set")
 | 
			
		||||
            node_flag = fn.attr("node_flag_set")
 | 
			
		||||
            node_flag_source = fn.attr("node_flag_source")
 | 
			
		||||
            node_flag_propagate = fn.attr("node_flag_propagate")
 | 
			
		||||
            variant_propagation_candidate = fn.attr("variant_propagation_candidate")
 | 
			
		||||
 | 
			
		||||
        class Body:
 | 
			
		||||
            node = fn.attr("node")
 | 
			
		||||
            virtual_node = fn.attr("virtual_node")
 | 
			
		||||
            node_platform = fn.attr("node_platform")
 | 
			
		||||
            node_os = fn.attr("node_os")
 | 
			
		||||
            node_target = fn.attr("node_target")
 | 
			
		||||
            variant_value = fn.attr("variant_value")
 | 
			
		||||
            node_compiler = fn.attr("node_compiler")
 | 
			
		||||
            node_compiler_version = fn.attr("node_compiler_version")
 | 
			
		||||
            node_flag = fn.attr("node_flag")
 | 
			
		||||
            node_flag_source = fn.attr("node_flag_source")
 | 
			
		||||
            node_flag_propagate = fn.attr("node_flag_propagate")
 | 
			
		||||
            variant_propagation_candidate = fn.attr("variant_propagation_candidate")
 | 
			
		||||
 | 
			
		||||
        f = Body if body else Head
 | 
			
		||||
        f = _Body if body else _Head
 | 
			
		||||
 | 
			
		||||
        if spec.name:
 | 
			
		||||
            clauses.append(f.node(spec.name) if not spec.virtual else f.virtual_node(spec.name))
 | 
			
		||||
@@ -2503,12 +2407,11 @@ def define_concrete_input_specs(self, specs, possible):
 | 
			
		||||
 | 
			
		||||
    def setup(
 | 
			
		||||
        self,
 | 
			
		||||
        driver: PyclingoDriver,
 | 
			
		||||
        specs: Sequence[spack.spec.Spec],
 | 
			
		||||
        *,
 | 
			
		||||
        reuse: Optional[List[spack.spec.Spec]] = None,
 | 
			
		||||
        allow_deprecated: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        """Generate an ASP program with relevant constraints for specs.
 | 
			
		||||
 | 
			
		||||
        This calls methods on the solve driver to set up the problem with
 | 
			
		||||
@@ -2516,7 +2419,6 @@ def setup(
 | 
			
		||||
        specs, as well as constraints from the specs themselves.
 | 
			
		||||
 | 
			
		||||
        Arguments:
 | 
			
		||||
            driver: driver instance of this solve
 | 
			
		||||
            specs: list of Specs to solve
 | 
			
		||||
            reuse: list of concrete specs that can be reused
 | 
			
		||||
            allow_deprecated: if True adds deprecated versions into the solve
 | 
			
		||||
@@ -2542,9 +2444,7 @@ def setup(
 | 
			
		||||
            if node.namespace is not None:
 | 
			
		||||
                self.explicitly_required_namespaces[node.name] = node.namespace
 | 
			
		||||
 | 
			
		||||
        # driver is used by all the functions below to add facts and
 | 
			
		||||
        # rules to generate an ASP program.
 | 
			
		||||
        self.gen = driver
 | 
			
		||||
        self.gen = ProblemInstanceBuilder()
 | 
			
		||||
 | 
			
		||||
        if not allow_deprecated:
 | 
			
		||||
            self.gen.fact(fn.deprecated_versions_not_allowed())
 | 
			
		||||
@@ -2648,6 +2548,29 @@ def setup(
 | 
			
		||||
        self.gen.h1("Target Constraints")
 | 
			
		||||
        self.define_target_constraints()
 | 
			
		||||
 | 
			
		||||
        self.gen.h1("Internal errors")
 | 
			
		||||
        self.internal_errors()
 | 
			
		||||
 | 
			
		||||
        return self.gen.value()
 | 
			
		||||
 | 
			
		||||
    def internal_errors(self):
 | 
			
		||||
        parent_dir = os.path.dirname(__file__)
 | 
			
		||||
 | 
			
		||||
        def visit(node):
 | 
			
		||||
            if ast_type(node) == ASTType.Rule:
 | 
			
		||||
                for term in node.body:
 | 
			
		||||
                    if ast_type(term) == ASTType.Literal:
 | 
			
		||||
                        if ast_type(term.atom) == ASTType.SymbolicAtom:
 | 
			
		||||
                            name = ast_sym(term.atom).name
 | 
			
		||||
                            if name == "internal_error":
 | 
			
		||||
                                arg = ast_sym(ast_sym(term.atom).arguments[0])
 | 
			
		||||
                                symbol = AspFunction(name)(arg.string)
 | 
			
		||||
                                self.assumptions.append((parse_term(str(symbol)), True))
 | 
			
		||||
                                self.gen.asp_problem.append(f"{{ {symbol} }}.\n")
 | 
			
		||||
 | 
			
		||||
        path = os.path.join(parent_dir, "concretize.lp")
 | 
			
		||||
        parse_files([path], visit)
 | 
			
		||||
 | 
			
		||||
    def define_runtime_constraints(self):
 | 
			
		||||
        """Define the constraints to be imposed on the runtimes"""
 | 
			
		||||
        recorder = RuntimePropertyRecorder(self)
 | 
			
		||||
@@ -2778,6 +2701,83 @@ def pkg_class(self, pkg_name: str) -> typing.Type["spack.package_base.PackageBas
 | 
			
		||||
        return spack.repo.PATH.get_pkg_class(request)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _Head:
 | 
			
		||||
    """ASP functions used to express spec clauses in the HEAD of a rule"""
 | 
			
		||||
 | 
			
		||||
    node = fn.attr("node")
 | 
			
		||||
    virtual_node = fn.attr("virtual_node")
 | 
			
		||||
    node_platform = fn.attr("node_platform_set")
 | 
			
		||||
    node_os = fn.attr("node_os_set")
 | 
			
		||||
    node_target = fn.attr("node_target_set")
 | 
			
		||||
    variant_value = fn.attr("variant_set")
 | 
			
		||||
    node_compiler = fn.attr("node_compiler_set")
 | 
			
		||||
    node_compiler_version = fn.attr("node_compiler_version_set")
 | 
			
		||||
    node_flag = fn.attr("node_flag_set")
 | 
			
		||||
    node_flag_source = fn.attr("node_flag_source")
 | 
			
		||||
    node_flag_propagate = fn.attr("node_flag_propagate")
 | 
			
		||||
    variant_propagation_candidate = fn.attr("variant_propagation_candidate")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _Body:
 | 
			
		||||
    """ASP functions used to express spec clauses in the BODY of a rule"""
 | 
			
		||||
 | 
			
		||||
    node = fn.attr("node")
 | 
			
		||||
    virtual_node = fn.attr("virtual_node")
 | 
			
		||||
    node_platform = fn.attr("node_platform")
 | 
			
		||||
    node_os = fn.attr("node_os")
 | 
			
		||||
    node_target = fn.attr("node_target")
 | 
			
		||||
    variant_value = fn.attr("variant_value")
 | 
			
		||||
    node_compiler = fn.attr("node_compiler")
 | 
			
		||||
    node_compiler_version = fn.attr("node_compiler_version")
 | 
			
		||||
    node_flag = fn.attr("node_flag")
 | 
			
		||||
    node_flag_source = fn.attr("node_flag_source")
 | 
			
		||||
    node_flag_propagate = fn.attr("node_flag_propagate")
 | 
			
		||||
    variant_propagation_candidate = fn.attr("variant_propagation_candidate")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProblemInstanceBuilder:
 | 
			
		||||
    """Provides an interface to construct a problem instance.
 | 
			
		||||
 | 
			
		||||
    Once all the facts and rules have been added, the problem instance can be retrieved with:
 | 
			
		||||
 | 
			
		||||
    >>> builder = ProblemInstanceBuilder()
 | 
			
		||||
    >>> ...
 | 
			
		||||
    >>> problem_instance = builder.value()
 | 
			
		||||
 | 
			
		||||
    The problem instance can be added directly to the "control" structure of clingo.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.asp_problem = []
 | 
			
		||||
 | 
			
		||||
    def fact(self, atom: AspFunction) -> None:
 | 
			
		||||
        symbol = atom.symbol() if hasattr(atom, "symbol") else atom
 | 
			
		||||
        self.asp_problem.append(f"{str(symbol)}.\n")
 | 
			
		||||
 | 
			
		||||
    def append(self, rule: str) -> None:
 | 
			
		||||
        self.asp_problem.append(rule)
 | 
			
		||||
 | 
			
		||||
    def title(self, header: str, char: str) -> None:
 | 
			
		||||
        self.asp_problem.append("\n")
 | 
			
		||||
        self.asp_problem.append("%" + (char * 76))
 | 
			
		||||
        self.asp_problem.append("\n")
 | 
			
		||||
        self.asp_problem.append(f"% {header}\n")
 | 
			
		||||
        self.asp_problem.append("%" + (char * 76))
 | 
			
		||||
        self.asp_problem.append("\n")
 | 
			
		||||
 | 
			
		||||
    def h1(self, header: str) -> None:
 | 
			
		||||
        self.title(header, "=")
 | 
			
		||||
 | 
			
		||||
    def h2(self, header: str) -> None:
 | 
			
		||||
        self.title(header, "-")
 | 
			
		||||
 | 
			
		||||
    def newline(self):
 | 
			
		||||
        self.asp_problem.append("\n")
 | 
			
		||||
 | 
			
		||||
    def value(self) -> str:
 | 
			
		||||
        return "".join(self.asp_problem)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RequirementParser:
 | 
			
		||||
    """Parses requirements from package.py files and configuration, and returns rules."""
 | 
			
		||||
 | 
			
		||||
@@ -3085,9 +3085,7 @@ def consume_facts(self):
 | 
			
		||||
        self._setup.gen.h2("Runtimes: rules")
 | 
			
		||||
        self._setup.gen.newline()
 | 
			
		||||
        for rule in self.rules:
 | 
			
		||||
            if not isinstance(self._setup.gen.out, llnl.util.lang.Devnull):
 | 
			
		||||
                self._setup.gen.out.write(rule)
 | 
			
		||||
            self._setup.gen.control.add("base", [], rule)
 | 
			
		||||
            self._setup.gen.append(rule)
 | 
			
		||||
 | 
			
		||||
        self._setup.gen.h2("Runtimes: conditions")
 | 
			
		||||
        for runtime_pkg in spack.repo.PATH.packages_with_tags("runtime"):
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user