index: avoid quadratic complexity through bulk update

This commit is contained in:
Harmen Stoppels 2025-01-29 17:07:29 +01:00
parent a77f903f4d
commit f95e246355
5 changed files with 85 additions and 100 deletions

View File

@@ -6,7 +6,7 @@
import os import os
import pathlib import pathlib
import sys import sys
from typing import Any, Dict, Optional, Tuple, Type, Union from typing import Any, Dict, Optional, Set, Tuple, Type, Union
import llnl.util.filesystem import llnl.util.filesystem
from llnl.url import allowed_archive from llnl.url import allowed_archive
@@ -503,36 +503,38 @@ def patch_for_package(self, sha256: str, pkg: "spack.package_base.PackageBase")
patch_dict["sha256"] = sha256 patch_dict["sha256"] = sha256
return from_dict(patch_dict, repository=self.repository) return from_dict(patch_dict, repository=self.repository)
def update_package(self, pkg_fullname: str) -> None: def update_packages(self, pkgs_fullname: Set[str]) -> None:
"""Update the patch cache. """Update the patch cache.
Args: Args:
pkg_fullname: package to update. pkg_fullname: package to update.
""" """
# remove this package from any patch entries that reference it. # remove this package from any patch entries that reference it.
empty = [] if self.index:
for sha256, package_to_patch in self.index.items(): empty = []
remove = [] for sha256, package_to_patch in self.index.items():
for fullname, patch_dict in package_to_patch.items(): remove = []
if patch_dict["owner"] == pkg_fullname: for fullname, patch_dict in package_to_patch.items():
remove.append(fullname) if patch_dict["owner"] in pkgs_fullname:
remove.append(fullname)
for fullname in remove: for fullname in remove:
package_to_patch.pop(fullname) package_to_patch.pop(fullname)
if not package_to_patch: if not package_to_patch:
empty.append(sha256) empty.append(sha256)
# remove any entries that are now empty # remove any entries that are now empty
for sha256 in empty: for sha256 in empty:
del self.index[sha256] del self.index[sha256]
# update the index with per-package patch indexes # update the index with per-package patch indexes
pkg_cls = self.repository.get_pkg_class(pkg_fullname) for pkg_fullname in pkgs_fullname:
partial_index = self._index_patches(pkg_cls, self.repository) pkg_cls = self.repository.get_pkg_class(pkg_fullname)
for sha256, package_to_patch in partial_index.items(): partial_index = self._index_patches(pkg_cls, self.repository)
p2p = self.index.setdefault(sha256, {}) for sha256, package_to_patch in partial_index.items():
p2p.update(package_to_patch) p2p = self.index.setdefault(sha256, {})
p2p.update(package_to_patch)
def update(self, other: "PatchCache") -> None: def update(self, other: "PatchCache") -> None:
"""Update this cache with the contents of another. """Update this cache with the contents of another.

View File

@@ -2,7 +2,7 @@
# #
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""Classes and functions to manage providers of virtual dependencies""" """Classes and functions to manage providers of virtual dependencies"""
from typing import Dict, List, Optional, Set from typing import Dict, Iterable, List, Optional, Set, Union
import spack.error import spack.error
import spack.spec import spack.spec
@@ -99,66 +99,56 @@ def __init__(
self.repository = repository self.repository = repository
self.restrict = restrict self.restrict = restrict
self.providers = {} self.providers = {}
if specs:
self.update_packages(specs)
specs = specs or [] def update_packages(self, specs: Iterable[Union[str, "spack.spec.Spec"]]):
for spec in specs:
if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec)
if self.repository.is_virtual_safe(spec.name):
continue
self.update(spec)
def update(self, spec):
"""Update the provider index with additional virtual specs. """Update the provider index with additional virtual specs.
Args: Args:
spec: spec potentially providing additional virtual specs spec: spec potentially providing additional virtual specs
""" """
if not isinstance(spec, spack.spec.Spec): for spec in specs:
spec = spack.spec.Spec(spec) if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec)
if not spec.name: if not spec.name or self.repository.is_virtual_safe(spec.name):
# Empty specs do not have a package # Only non-virtual packages with name can provide virtual specs.
return continue
msg = "cannot update an index passing the virtual spec '{}'".format(spec.name) pkg_provided = self.repository.get_pkg_class(spec.name).provided
assert not self.repository.is_virtual_safe(spec.name), msg for provider_spec_readonly, provided_specs in pkg_provided.items():
for provided_spec in provided_specs:
# TODO: fix this comment.
# We want satisfaction other than flags
provider_spec = provider_spec_readonly.copy()
provider_spec.compiler_flags = spec.compiler_flags.copy()
pkg_provided = self.repository.get_pkg_class(spec.name).provided if spec.intersects(provider_spec, deps=False):
for provider_spec_readonly, provided_specs in pkg_provided.items(): provided_name = provided_spec.name
for provided_spec in provided_specs:
# TODO: fix this comment.
# We want satisfaction other than flags
provider_spec = provider_spec_readonly.copy()
provider_spec.compiler_flags = spec.compiler_flags.copy()
if spec.intersects(provider_spec, deps=False): provider_map = self.providers.setdefault(provided_name, {})
provided_name = provided_spec.name if provided_spec not in provider_map:
provider_map[provided_spec] = set()
provider_map = self.providers.setdefault(provided_name, {}) if self.restrict:
if provided_spec not in provider_map: provider_set = provider_map[provided_spec]
provider_map[provided_spec] = set()
if self.restrict: # If this package existed in the index before,
provider_set = provider_map[provided_spec] # need to take the old versions out, as they're
# now more constrained.
old = {s for s in provider_set if s.name == spec.name}
provider_set.difference_update(old)
# If this package existed in the index before, # Now add the new version.
# need to take the old versions out, as they're provider_set.add(spec)
# now more constrained.
old = set([s for s in provider_set if s.name == spec.name])
provider_set.difference_update(old)
# Now add the new version. else:
provider_set.add(spec) # Before putting the spec in the map, constrain
# it so that it provides what was asked for.
else: constrained = spec.copy()
# Before putting the spec in the map, constrain constrained.constrain(provider_spec)
# it so that it provides what was asked for. provider_map[provided_spec].add(constrained)
constrained = spec.copy()
constrained.constrain(provider_spec)
provider_map[provided_spec].add(constrained)
def to_json(self, stream=None): def to_json(self, stream=None):
"""Dump a JSON representation of this object. """Dump a JSON representation of this object.
@@ -193,14 +183,13 @@ def merge(self, other):
spdict[provided_spec] = spdict[provided_spec].union(opdict[provided_spec]) spdict[provided_spec] = spdict[provided_spec].union(opdict[provided_spec])
def remove_provider(self, pkg_name): def remove_providers(self, pkgs_fullname: Set[str]):
"""Remove a provider from the ProviderIndex.""" """Remove a provider from the ProviderIndex."""
empty_pkg_dict = [] empty_pkg_dict = []
for pkg, pkg_dict in self.providers.items(): for pkg, pkg_dict in self.providers.items():
empty_pset = [] empty_pset = []
for provided, pset in pkg_dict.items(): for provided, pset in pkg_dict.items():
same_name = set(p for p in pset if p.fullname == pkg_name) pset.difference_update(pkgs_fullname)
pset.difference_update(same_name)
if not pset: if not pset:
empty_pset.append(provided) empty_pset.append(provided)

View File

@@ -465,7 +465,7 @@ def read(self, stream):
"""Read this index from a provided file object.""" """Read this index from a provided file object."""
@abc.abstractmethod @abc.abstractmethod
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
"""Update the index in memory with information about a package.""" """Update the index in memory with information about a package."""
@abc.abstractmethod @abc.abstractmethod
@@ -482,8 +482,8 @@ def _create(self):
def read(self, stream): def read(self, stream):
self.index = spack.tag.TagIndex.from_json(stream, self.repository) self.index = spack.tag.TagIndex.from_json(stream, self.repository)
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
self.index.update_package(pkg_fullname.split(".")[-1]) self.index.update_packages({p.split(".")[-1] for p in pkgs_fullname})
def write(self, stream): def write(self, stream):
self.index.to_json(stream) self.index.to_json(stream)
@@ -498,15 +498,14 @@ def _create(self):
def read(self, stream): def read(self, stream):
self.index = spack.provider_index.ProviderIndex.from_json(stream, self.repository) self.index = spack.provider_index.ProviderIndex.from_json(stream, self.repository)
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
name = pkg_fullname.split(".")[-1]
is_virtual = ( is_virtual = (
not self.repository.exists(name) or self.repository.get_pkg_class(name).virtual lambda name: not self.repository.exists(name)
or self.repository.get_pkg_class(name).virtual
) )
if is_virtual: non_virtual_pkgs_fullname = {p for p in pkgs_fullname if not is_virtual(p.split(".")[-1])}
return self.index.remove_providers(non_virtual_pkgs_fullname)
self.index.remove_provider(pkg_fullname) self.index.update_packages(non_virtual_pkgs_fullname)
self.index.update(pkg_fullname)
def write(self, stream): def write(self, stream):
self.index.to_json(stream) self.index.to_json(stream)
@@ -531,8 +530,8 @@ def read(self, stream):
def write(self, stream): def write(self, stream):
self.index.to_json(stream) self.index.to_json(stream)
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
self.index.update_package(pkg_fullname) self.index.update_packages(pkgs_fullname)
class RepoIndex: class RepoIndex:
@@ -622,9 +621,7 @@ def _build_index(self, name: str, indexer: Indexer):
if new_index_mtime != index_mtime: if new_index_mtime != index_mtime:
needs_update = self.checker.modified_since(new_index_mtime) needs_update = self.checker.modified_since(new_index_mtime)
for pkg_name in needs_update: indexer.update({f"{self.namespace}.{pkg_name}" for pkg_name in needs_update})
indexer.update(f"{self.namespace}.{pkg_name}")
indexer.write(new) indexer.write(new)
return indexer.index return indexer.index

View File

@@ -5,6 +5,7 @@
import collections import collections
import copy import copy
from collections.abc import Mapping from collections.abc import Mapping
from typing import Set
import spack.error import spack.error
import spack.repo import spack.repo
@@ -110,23 +111,20 @@ def merge(self, other):
spkgs, opkgs = self.tags[tag], other.tags[tag] spkgs, opkgs = self.tags[tag], other.tags[tag]
self.tags[tag] = sorted(list(set(spkgs + opkgs))) self.tags[tag] = sorted(list(set(spkgs + opkgs)))
def update_package(self, pkg_name): def update_packages(self, pkg_names: Set[str]):
"""Updates a package in the tag index. """Updates a package in the tag index."""
Args:
pkg_name (str): name of the package to be removed from the index
"""
pkg_cls = self.repository.get_pkg_class(pkg_name)
# Remove the package from the list of packages, if present # Remove the package from the list of packages, if present
for pkg_list in self._tag_dict.values(): for pkg_list in self._tag_dict.values():
if pkg_name in pkg_list: if pkg_names.isdisjoint(pkg_list):
pkg_list.remove(pkg_name) continue
pkg_list[:] = [pkg for pkg in pkg_list if pkg not in pkg_names]
# Add it again under the appropriate tags # Add it again under the appropriate tags
for tag in getattr(pkg_cls, "tags", []): for pkg_name in pkg_names:
tag = tag.lower() pkg_cls = self.repository.get_pkg_class(pkg_name)
self._tag_dict[tag].append(pkg_cls.name) for tag in getattr(pkg_cls, "tags", []):
tag = tag.lower()
self._tag_dict[tag].append(pkg_cls.name)
class TagIndexError(spack.error.SpackError): class TagIndexError(spack.error.SpackError):

View File

@@ -154,7 +154,6 @@ def test_tag_no_tags(mock_packages):
def test_tag_update_package(mock_packages): def test_tag_update_package(mock_packages):
mock_index = mock_packages.tag_index mock_index = mock_packages.tag_index
index = spack.tag.TagIndex(repository=mock_packages) index = spack.tag.TagIndex(repository=mock_packages)
for name in spack.repo.all_package_names(): index.update_packages(set(spack.repo.all_package_names()))
index.update_package(name)
ensure_tags_results_equal(mock_index.tags, index.tags) ensure_tags_results_equal(mock_index.tags, index.tags)