Rework the encoding to introduce node(ID, Package) nested facts

So far the encoding has a single ID per package, i.e. all the
facts will be node(0, Package). This will prepare the stage for
extending this logic and having multiple nodes from the same
package in a DAG.
This commit is contained in:
Massimiliano Culpo 2023-06-16 15:25:30 +02:00 committed by Todd Gamblin
parent 907a80ca71
commit 27ab53b68a
3 changed files with 631 additions and 539 deletions

View File

@ -515,15 +515,17 @@ def _compute_specs_from_answer_set(self):
best = min(self.answers) best = min(self.answers)
opt, _, answer = best opt, _, answer = best
for input_spec in self.abstract_specs: for input_spec in self.abstract_specs:
key = input_spec.name node = SpecBuilder.root_node(pkg=input_spec.name)
if input_spec.virtual: if input_spec.virtual:
providers = [spec.name for spec in answer.values() if spec.package.provides(key)] providers = [
key = providers[0] spec.name for spec in answer.values() if spec.package.provides(input_spec.name)
candidate = answer.get(key) ]
node = SpecBuilder.root_node(pkg=providers[0])
candidate = answer.get(node)
if candidate and candidate.satisfies(input_spec): if candidate and candidate.satisfies(input_spec):
self._concrete_specs.append(answer[key]) self._concrete_specs.append(answer[node])
self._concrete_specs_by_input[input_spec] = answer[key] self._concrete_specs_by_input[input_spec] = answer[node]
else: else:
self._unsolved_specs.append(input_spec) self._unsolved_specs.append(input_spec)
@ -2426,6 +2428,18 @@ class SpecBuilder:
) )
) )
node_regex = re.compile(r"node\(\d,\"(.*)\"\)")
@staticmethod
def root_node(*, pkg: str) -> str:
"""Given a package name, returns the string representation of the root node in
the ASP encoding.
Args:
pkg: name of a package
"""
return f'node(0,"{pkg}")'
def __init__(self, specs, hash_lookup=None): def __init__(self, specs, hash_lookup=None):
self._specs = {} self._specs = {}
self._result = None self._result = None
@ -2438,100 +2452,121 @@ def __init__(self, specs, hash_lookup=None):
# from this dictionary during reconstruction # from this dictionary during reconstruction
self._hash_lookup = hash_lookup or {} self._hash_lookup = hash_lookup or {}
def hash(self, pkg, h): @staticmethod
if pkg not in self._specs: def extract_pkg(node: str) -> str:
self._specs[pkg] = self._hash_lookup[h] """Extracts the package name from a node fact, and returns it.
self._hash_specs.append(pkg)
def node(self, pkg): Args:
if pkg not in self._specs: node: node from which the package name is to be extracted
self._specs[pkg] = spack.spec.Spec(pkg) """
m = SpecBuilder.node_regex.match(node)
if m is None:
raise spack.error.SpackError(f"cannot extract package information from '{node}'")
def _arch(self, pkg): return m.group(1)
arch = self._specs[pkg].architecture
def hash(self, node, h):
if node not in self._specs:
self._specs[node] = self._hash_lookup[h]
self._hash_specs.append(node)
def node(self, node):
pkg = self.extract_pkg(node)
if node not in self._specs:
self._specs[node] = spack.spec.Spec(pkg)
def _arch(self, node):
arch = self._specs[node].architecture
if not arch: if not arch:
arch = spack.spec.ArchSpec() arch = spack.spec.ArchSpec()
self._specs[pkg].architecture = arch self._specs[node].architecture = arch
return arch return arch
def node_platform(self, pkg, platform): def node_platform(self, node, platform):
self._arch(pkg).platform = platform self._arch(node).platform = platform
def node_os(self, pkg, os): def node_os(self, node, os):
self._arch(pkg).os = os self._arch(node).os = os
def node_target(self, pkg, target): def node_target(self, node, target):
self._arch(pkg).target = target self._arch(node).target = target
def variant_value(self, pkg, name, value): def variant_value(self, node, name, value):
# FIXME: is there a way not to special case 'dev_path' everywhere? # FIXME: is there a way not to special case 'dev_path' everywhere?
if name == "dev_path": if name == "dev_path":
self._specs[pkg].variants.setdefault( self._specs[node].variants.setdefault(
name, spack.variant.SingleValuedVariant(name, value) name, spack.variant.SingleValuedVariant(name, value)
) )
return return
if name == "patches": if name == "patches":
self._specs[pkg].variants.setdefault( self._specs[node].variants.setdefault(
name, spack.variant.MultiValuedVariant(name, value) name, spack.variant.MultiValuedVariant(name, value)
) )
return return
self._specs[pkg].update_variant_validate(name, value) self._specs[node].update_variant_validate(name, value)
def version(self, pkg, version): def version(self, node, version):
self._specs[pkg].versions = vn.VersionList([vn.Version(version)]) self._specs[node].versions = vn.VersionList([vn.Version(version)])
def node_compiler_version(self, pkg, compiler, version): def node_compiler_version(self, node, compiler, version):
self._specs[pkg].compiler = spack.spec.CompilerSpec(compiler) self._specs[node].compiler = spack.spec.CompilerSpec(compiler)
self._specs[pkg].compiler.versions = vn.VersionList([vn.Version(version)]) self._specs[node].compiler.versions = vn.VersionList([vn.Version(version)])
def node_flag_compiler_default(self, pkg): def node_flag_compiler_default(self, node):
self._flag_compiler_defaults.add(pkg) self._flag_compiler_defaults.add(node)
def node_flag(self, pkg, flag_type, flag): def node_flag(self, node, flag_type, flag):
self._specs[pkg].compiler_flags.add_flag(flag_type, flag, False) self._specs[node].compiler_flags.add_flag(flag_type, flag, False)
def node_flag_source(self, pkg, flag_type, source): def node_flag_source(self, node, flag_type, source):
self._flag_sources[(pkg, flag_type)].add(source) self._flag_sources[(node, flag_type)].add(source)
def no_flags(self, pkg, flag_type): def no_flags(self, node, flag_type):
self._specs[pkg].compiler_flags[flag_type] = [] self._specs[node].compiler_flags[flag_type] = []
def external_spec_selected(self, pkg, idx): def external_spec_selected(self, node, idx):
"""This means that the external spec and index idx """This means that the external spec and index idx
has been selected for this package. has been selected for this package.
""" """
packages_yaml = spack.config.get("packages") packages_yaml = spack.config.get("packages")
packages_yaml = _normalize_packages_yaml(packages_yaml) packages_yaml = _normalize_packages_yaml(packages_yaml)
pkg = self.extract_pkg(node)
spec_info = packages_yaml[pkg]["externals"][int(idx)] spec_info = packages_yaml[pkg]["externals"][int(idx)]
self._specs[pkg].external_path = spec_info.get("prefix", None) self._specs[node].external_path = spec_info.get("prefix", None)
self._specs[pkg].external_modules = spack.spec.Spec._format_module_list( self._specs[node].external_modules = spack.spec.Spec._format_module_list(
spec_info.get("modules", None) spec_info.get("modules", None)
) )
self._specs[pkg].extra_attributes = spec_info.get("extra_attributes", {}) self._specs[node].extra_attributes = spec_info.get("extra_attributes", {})
# If this is an extension, update the dependencies to include the extendee # If this is an extension, update the dependencies to include the extendee
package = self._specs[pkg].package_class(self._specs[pkg]) package = self._specs[node].package_class(self._specs[node])
extendee_spec = package.extendee_spec extendee_spec = package.extendee_spec
if extendee_spec:
package.update_external_dependencies(self._specs.get(extendee_spec.name, None))
def depends_on(self, pkg, dep, type): if extendee_spec:
dependencies = self._specs[pkg].edges_to_dependencies(name=dep) extendee_node = SpecBuilder.root_node(pkg=extendee_spec.name)
package.update_external_dependencies(self._specs.get(extendee_node, None))
def depends_on(self, parent_node, dependency_node, type):
dependencies = self._specs[parent_node].edges_to_dependencies(name=dependency_node)
# TODO: assertion to be removed when cross-compilation is handled correctly # TODO: assertion to be removed when cross-compilation is handled correctly
msg = "Current solver does not handle multiple dependency edges of the same name" msg = "Current solver does not handle multiple dependency edges of the same name"
assert len(dependencies) < 2, msg assert len(dependencies) < 2, msg
if not dependencies: if not dependencies:
self._specs[pkg].add_dependency_edge(self._specs[dep], deptypes=(type,), virtuals=()) self._specs[parent_node].add_dependency_edge(
self._specs[dependency_node], deptypes=(type,), virtuals=()
)
else: else:
# TODO: This assumes that each solve unifies dependencies # TODO: This assumes that each solve unifies dependencies
dependencies[0].update_deptypes(deptypes=(type,)) dependencies[0].update_deptypes(deptypes=(type,))
def virtual_on_edge(self, pkg, provider, virtual): def virtual_on_edge(self, parent_node, provider_node, virtual):
dependencies = self._specs[pkg].edges_to_dependencies(name=provider) provider = self.extract_pkg(provider_node)
dependencies = self._specs[parent_node].edges_to_dependencies(name=provider)
assert len(dependencies) == 1 assert len(dependencies) == 1
dependencies[0].update_virtuals((virtual,)) dependencies[0].update_virtuals((virtual,))
@ -2562,17 +2597,22 @@ def reorder_flags(self):
# order is determined by the DAG. A spec's flags come after any of its ancestors # order is determined by the DAG. A spec's flags come after any of its ancestors
# on the compile line # on the compile line
source_key = (spec.name, flag_type) node = SpecBuilder.root_node(pkg=spec.name)
source_key = (node, flag_type)
if source_key in self._flag_sources: if source_key in self._flag_sources:
order = [s.name for s in spec.traverse(order="post", direction="parents")] order = [
SpecBuilder.root_node(pkg=s.name)
for s in spec.traverse(order="post", direction="parents")
]
sorted_sources = sorted( sorted_sources = sorted(
self._flag_sources[source_key], key=lambda s: order.index(s) self._flag_sources[source_key], key=lambda s: order.index(s)
) )
# add flags from each source, lowest to highest precedence # add flags from each source, lowest to highest precedence
for name in sorted_sources: for node in sorted_sources:
all_src_flags = list() all_src_flags = list()
per_pkg_sources = [self._specs[name]] per_pkg_sources = [self._specs[node]]
name = self.extract_pkg(node)
if name in cmd_specs: if name in cmd_specs:
per_pkg_sources.append(cmd_specs[name]) per_pkg_sources.append(cmd_specs[name])
for source in per_pkg_sources: for source in per_pkg_sources:
@ -2645,14 +2685,14 @@ def build_specs(self, function_tuples):
# solving but don't construct anything. Do not ignore error # solving but don't construct anything. Do not ignore error
# predicates on virtual packages. # predicates on virtual packages.
if name != "error": if name != "error":
pkg = args[0] pkg = self.extract_pkg(args[0])
if spack.repo.PATH.is_virtual(pkg): if spack.repo.PATH.is_virtual(pkg):
continue continue
# if we've already gotten a concrete spec for this pkg, # if we've already gotten a concrete spec for this pkg,
# do not bother calling actions on it except for node_flag_source, # do not bother calling actions on it except for node_flag_source,
# since node_flag_source is tracking information not in the spec itself # since node_flag_source is tracking information not in the spec itself
spec = self._specs.get(pkg) spec = self._specs.get(args[0])
if spec and spec.concrete: if spec and spec.concrete:
if name != "node_flag_source": if name != "node_flag_source":
continue continue

File diff suppressed because it is too large Load Diff

View File

@ -2983,9 +2983,10 @@ def _new_concretize(self, tests=False):
providers = [spec.name for spec in answer.values() if spec.package.provides(name)] providers = [spec.name for spec in answer.values() if spec.package.provides(name)]
name = providers[0] name = providers[0]
assert name in answer node = spack.solver.asp.SpecBuilder.root_node(pkg=name)
assert node in answer, f"cannot find {name} in the list of specs {','.join(answer.keys())}"
concretized = answer[name] concretized = answer[node]
self._dup(concretized) self._dup(concretized)
def concretize(self, tests=False): def concretize(self, tests=False):
@ -3519,7 +3520,8 @@ def update_variant_validate(self, variant_name, values):
for value in values: for value in values:
if self.variants.get(variant_name): if self.variants.get(variant_name):
msg = ( msg = (
"Cannot append a value to a single-valued " "variant with an already set value" f"cannot append the new value '{value}' to the single-valued "
f"variant '{self.variants[variant_name]}'"
) )
assert pkg_variant.multi, msg assert pkg_variant.multi, msg
self.variants[variant_name].append(value) self.variants[variant_name].append(value)