Compare commits

...

2 Commits

Author SHA1 Message Date
Harmen Stoppels
10acffc92e fix incorrect type annotation of spack.provider_index._IndexBase.providers 2025-01-29 17:38:13 +01:00
Harmen Stoppels
f95e246355 index: avoid quadratic complexity through bulk update 2025-01-29 17:14:34 +01:00
5 changed files with 86 additions and 101 deletions

View File

@ -6,7 +6,7 @@
import os import os
import pathlib import pathlib
import sys import sys
from typing import Any, Dict, Optional, Tuple, Type, Union from typing import Any, Dict, Optional, Set, Tuple, Type, Union
import llnl.util.filesystem import llnl.util.filesystem
from llnl.url import allowed_archive from llnl.url import allowed_archive
@ -503,18 +503,19 @@ def patch_for_package(self, sha256: str, pkg: "spack.package_base.PackageBase")
patch_dict["sha256"] = sha256 patch_dict["sha256"] = sha256
return from_dict(patch_dict, repository=self.repository) return from_dict(patch_dict, repository=self.repository)
def update_package(self, pkg_fullname: str) -> None: def update_packages(self, pkgs_fullname: Set[str]) -> None:
"""Update the patch cache. """Update the patch cache.
Args: Args:
pkg_fullname: package to update. pkg_fullname: package to update.
""" """
# remove this package from any patch entries that reference it. # remove this package from any patch entries that reference it.
if self.index:
empty = [] empty = []
for sha256, package_to_patch in self.index.items(): for sha256, package_to_patch in self.index.items():
remove = [] remove = []
for fullname, patch_dict in package_to_patch.items(): for fullname, patch_dict in package_to_patch.items():
if patch_dict["owner"] == pkg_fullname: if patch_dict["owner"] in pkgs_fullname:
remove.append(fullname) remove.append(fullname)
for fullname in remove: for fullname in remove:
@ -528,6 +529,7 @@ def update_package(self, pkg_fullname: str) -> None:
del self.index[sha256] del self.index[sha256]
# update the index with per-package patch indexes # update the index with per-package patch indexes
for pkg_fullname in pkgs_fullname:
pkg_cls = self.repository.get_pkg_class(pkg_fullname) pkg_cls = self.repository.get_pkg_class(pkg_fullname)
partial_index = self._index_patches(pkg_cls, self.repository) partial_index = self._index_patches(pkg_cls, self.repository)
for sha256, package_to_patch in partial_index.items(): for sha256, package_to_patch in partial_index.items():

View File

@ -2,7 +2,7 @@
# #
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""Classes and functions to manage providers of virtual dependencies""" """Classes and functions to manage providers of virtual dependencies"""
from typing import Dict, List, Optional, Set from typing import Dict, Iterable, List, Optional, Set, Union
import spack.error import spack.error
import spack.spec import spack.spec
@ -26,7 +26,7 @@ class _IndexBase:
#: Calling providers_for(spec) will find specs that provide a #: Calling providers_for(spec) will find specs that provide a
#: matching implementation of MPI. Derived class need to construct #: matching implementation of MPI. Derived class need to construct
#: this attribute according to the semantics above. #: this attribute according to the semantics above.
providers: Dict[str, Dict[str, Set[str]]] providers: Dict[str, Dict["spack.spec.Spec", Set["spack.spec.Spec"]]]
def providers_for(self, virtual_spec): def providers_for(self, virtual_spec):
"""Return a list of specs of all packages that provide virtual """Return a list of specs of all packages that provide virtual
@ -99,32 +99,22 @@ def __init__(
self.repository = repository self.repository = repository
self.restrict = restrict self.restrict = restrict
self.providers = {} self.providers = {}
if specs:
self.update_packages(specs)
specs = specs or [] def update_packages(self, specs: Iterable[Union[str, "spack.spec.Spec"]]):
for spec in specs:
if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec)
if self.repository.is_virtual_safe(spec.name):
continue
self.update(spec)
def update(self, spec):
"""Update the provider index with additional virtual specs. """Update the provider index with additional virtual specs.
Args: Args:
spec: spec potentially providing additional virtual specs spec: spec potentially providing additional virtual specs
""" """
for spec in specs:
if not isinstance(spec, spack.spec.Spec): if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec) spec = spack.spec.Spec(spec)
if not spec.name: if not spec.name or self.repository.is_virtual_safe(spec.name):
# Empty specs do not have a package # Only non-virtual packages with name can provide virtual specs.
return continue
msg = "cannot update an index passing the virtual spec '{}'".format(spec.name)
assert not self.repository.is_virtual_safe(spec.name), msg
pkg_provided = self.repository.get_pkg_class(spec.name).provided pkg_provided = self.repository.get_pkg_class(spec.name).provided
for provider_spec_readonly, provided_specs in pkg_provided.items(): for provider_spec_readonly, provided_specs in pkg_provided.items():
@ -147,7 +137,7 @@ def update(self, spec):
# If this package existed in the index before, # If this package existed in the index before,
# need to take the old versions out, as they're # need to take the old versions out, as they're
# now more constrained. # now more constrained.
old = set([s for s in provider_set if s.name == spec.name]) old = {s for s in provider_set if s.name == spec.name}
provider_set.difference_update(old) provider_set.difference_update(old)
# Now add the new version. # Now add the new version.
@ -193,14 +183,13 @@ def merge(self, other):
spdict[provided_spec] = spdict[provided_spec].union(opdict[provided_spec]) spdict[provided_spec] = spdict[provided_spec].union(opdict[provided_spec])
def remove_provider(self, pkg_name): def remove_providers(self, pkgs_fullname: Set[str]):
"""Remove a provider from the ProviderIndex.""" """Remove a provider from the ProviderIndex."""
empty_pkg_dict = [] empty_pkg_dict = []
for pkg, pkg_dict in self.providers.items(): for pkg, pkg_dict in self.providers.items():
empty_pset = [] empty_pset = []
for provided, pset in pkg_dict.items(): for provided, pset in pkg_dict.items():
same_name = set(p for p in pset if p.fullname == pkg_name) pset.difference_update(pkgs_fullname)
pset.difference_update(same_name)
if not pset: if not pset:
empty_pset.append(provided) empty_pset.append(provided)

View File

@ -465,7 +465,7 @@ def read(self, stream):
"""Read this index from a provided file object.""" """Read this index from a provided file object."""
@abc.abstractmethod @abc.abstractmethod
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
"""Update the index in memory with information about a package.""" """Update the index in memory with information about a package."""
@abc.abstractmethod @abc.abstractmethod
@ -482,8 +482,8 @@ def _create(self):
def read(self, stream): def read(self, stream):
self.index = spack.tag.TagIndex.from_json(stream, self.repository) self.index = spack.tag.TagIndex.from_json(stream, self.repository)
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
self.index.update_package(pkg_fullname.split(".")[-1]) self.index.update_packages({p.split(".")[-1] for p in pkgs_fullname})
def write(self, stream): def write(self, stream):
self.index.to_json(stream) self.index.to_json(stream)
@ -498,15 +498,14 @@ def _create(self):
def read(self, stream): def read(self, stream):
self.index = spack.provider_index.ProviderIndex.from_json(stream, self.repository) self.index = spack.provider_index.ProviderIndex.from_json(stream, self.repository)
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
name = pkg_fullname.split(".")[-1]
is_virtual = ( is_virtual = (
not self.repository.exists(name) or self.repository.get_pkg_class(name).virtual lambda name: not self.repository.exists(name)
or self.repository.get_pkg_class(name).virtual
) )
if is_virtual: non_virtual_pkgs_fullname = {p for p in pkgs_fullname if not is_virtual(p.split(".")[-1])}
return self.index.remove_providers(non_virtual_pkgs_fullname)
self.index.remove_provider(pkg_fullname) self.index.update_packages(non_virtual_pkgs_fullname)
self.index.update(pkg_fullname)
def write(self, stream): def write(self, stream):
self.index.to_json(stream) self.index.to_json(stream)
@ -531,8 +530,8 @@ def read(self, stream):
def write(self, stream): def write(self, stream):
self.index.to_json(stream) self.index.to_json(stream)
def update(self, pkg_fullname): def update(self, pkgs_fullname: Set[str]):
self.index.update_package(pkg_fullname) self.index.update_packages(pkgs_fullname)
class RepoIndex: class RepoIndex:
@ -622,9 +621,7 @@ def _build_index(self, name: str, indexer: Indexer):
if new_index_mtime != index_mtime: if new_index_mtime != index_mtime:
needs_update = self.checker.modified_since(new_index_mtime) needs_update = self.checker.modified_since(new_index_mtime)
for pkg_name in needs_update: indexer.update({f"{self.namespace}.{pkg_name}" for pkg_name in needs_update})
indexer.update(f"{self.namespace}.{pkg_name}")
indexer.write(new) indexer.write(new)
return indexer.index return indexer.index

View File

@ -5,6 +5,7 @@
import collections import collections
import copy import copy
from collections.abc import Mapping from collections.abc import Mapping
from typing import Set
import spack.error import spack.error
import spack.repo import spack.repo
@ -110,20 +111,17 @@ def merge(self, other):
spkgs, opkgs = self.tags[tag], other.tags[tag] spkgs, opkgs = self.tags[tag], other.tags[tag]
self.tags[tag] = sorted(list(set(spkgs + opkgs))) self.tags[tag] = sorted(list(set(spkgs + opkgs)))
def update_package(self, pkg_name): def update_packages(self, pkg_names: Set[str]):
"""Updates a package in the tag index. """Updates a package in the tag index."""
Args:
pkg_name (str): name of the package to be removed from the index
"""
pkg_cls = self.repository.get_pkg_class(pkg_name)
# Remove the package from the list of packages, if present # Remove the package from the list of packages, if present
for pkg_list in self._tag_dict.values(): for pkg_list in self._tag_dict.values():
if pkg_name in pkg_list: if pkg_names.isdisjoint(pkg_list):
pkg_list.remove(pkg_name) continue
pkg_list[:] = [pkg for pkg in pkg_list if pkg not in pkg_names]
# Add it again under the appropriate tags # Add it again under the appropriate tags
for pkg_name in pkg_names:
pkg_cls = self.repository.get_pkg_class(pkg_name)
for tag in getattr(pkg_cls, "tags", []): for tag in getattr(pkg_cls, "tags", []):
tag = tag.lower() tag = tag.lower()
self._tag_dict[tag].append(pkg_cls.name) self._tag_dict[tag].append(pkg_cls.name)

View File

@ -154,7 +154,6 @@ def test_tag_no_tags(mock_packages):
def test_tag_update_package(mock_packages): def test_tag_update_package(mock_packages):
mock_index = mock_packages.tag_index mock_index = mock_packages.tag_index
index = spack.tag.TagIndex(repository=mock_packages) index = spack.tag.TagIndex(repository=mock_packages)
for name in spack.repo.all_package_names(): index.update_packages(set(spack.repo.all_package_names()))
index.update_package(name)
ensure_tags_results_equal(mock_index.tags, index.tags) ensure_tags_results_equal(mock_index.tags, index.tags)