Compare commits

...

2 Commits

Author SHA1 Message Date
Harmen Stoppels
10acffc92e fix incorrect type annotation of spack.provider_index._IndexBase.providers 2025-01-29 17:38:13 +01:00
Harmen Stoppels
f95e246355 index: avoid quadratic complexity through bulk update 2025-01-29 17:14:34 +01:00
5 changed files with 86 additions and 101 deletions

View File

@ -6,7 +6,7 @@
import os
import pathlib
import sys
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Set, Tuple, Type, Union
import llnl.util.filesystem
from llnl.url import allowed_archive
@ -503,18 +503,19 @@ def patch_for_package(self, sha256: str, pkg: "spack.package_base.PackageBase")
patch_dict["sha256"] = sha256
return from_dict(patch_dict, repository=self.repository)
def update_package(self, pkg_fullname: str) -> None:
def update_packages(self, pkgs_fullname: Set[str]) -> None:
"""Update the patch cache.
Args:
pkg_fullname: package to update.
"""
# remove this package from any patch entries that reference it.
if self.index:
empty = []
for sha256, package_to_patch in self.index.items():
remove = []
for fullname, patch_dict in package_to_patch.items():
if patch_dict["owner"] == pkg_fullname:
if patch_dict["owner"] in pkgs_fullname:
remove.append(fullname)
for fullname in remove:
@ -528,6 +529,7 @@ def update_package(self, pkg_fullname: str) -> None:
del self.index[sha256]
# update the index with per-package patch indexes
for pkg_fullname in pkgs_fullname:
pkg_cls = self.repository.get_pkg_class(pkg_fullname)
partial_index = self._index_patches(pkg_cls, self.repository)
for sha256, package_to_patch in partial_index.items():

View File

@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""Classes and functions to manage providers of virtual dependencies"""
from typing import Dict, List, Optional, Set
from typing import Dict, Iterable, List, Optional, Set, Union
import spack.error
import spack.spec
@ -26,7 +26,7 @@ class _IndexBase:
#: Calling providers_for(spec) will find specs that provide a
#: matching implementation of MPI. Derived class need to construct
#: this attribute according to the semantics above.
providers: Dict[str, Dict[str, Set[str]]]
providers: Dict[str, Dict["spack.spec.Spec", Set["spack.spec.Spec"]]]
def providers_for(self, virtual_spec):
"""Return a list of specs of all packages that provide virtual
@ -99,32 +99,22 @@ def __init__(
self.repository = repository
self.restrict = restrict
self.providers = {}
if specs:
self.update_packages(specs)
specs = specs or []
for spec in specs:
if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec)
if self.repository.is_virtual_safe(spec.name):
continue
self.update(spec)
def update(self, spec):
def update_packages(self, specs: Iterable[Union[str, "spack.spec.Spec"]]):
"""Update the provider index with additional virtual specs.
Args:
spec: spec potentially providing additional virtual specs
"""
for spec in specs:
if not isinstance(spec, spack.spec.Spec):
spec = spack.spec.Spec(spec)
if not spec.name:
# Empty specs do not have a package
return
msg = "cannot update an index passing the virtual spec '{}'".format(spec.name)
assert not self.repository.is_virtual_safe(spec.name), msg
if not spec.name or self.repository.is_virtual_safe(spec.name):
# Only non-virtual packages with name can provide virtual specs.
continue
pkg_provided = self.repository.get_pkg_class(spec.name).provided
for provider_spec_readonly, provided_specs in pkg_provided.items():
@ -147,7 +137,7 @@ def update(self, spec):
# If this package existed in the index before,
# need to take the old versions out, as they're
# now more constrained.
old = set([s for s in provider_set if s.name == spec.name])
old = {s for s in provider_set if s.name == spec.name}
provider_set.difference_update(old)
# Now add the new version.
@ -193,14 +183,13 @@ def merge(self, other):
spdict[provided_spec] = spdict[provided_spec].union(opdict[provided_spec])
def remove_provider(self, pkg_name):
def remove_providers(self, pkgs_fullname: Set[str]):
"""Remove a provider from the ProviderIndex."""
empty_pkg_dict = []
for pkg, pkg_dict in self.providers.items():
empty_pset = []
for provided, pset in pkg_dict.items():
same_name = set(p for p in pset if p.fullname == pkg_name)
pset.difference_update(same_name)
pset.difference_update(pkgs_fullname)
if not pset:
empty_pset.append(provided)

View File

@ -465,7 +465,7 @@ def read(self, stream):
"""Read this index from a provided file object."""
@abc.abstractmethod
def update(self, pkg_fullname):
def update(self, pkgs_fullname: Set[str]):
"""Update the index in memory with information about a package."""
@abc.abstractmethod
@ -482,8 +482,8 @@ def _create(self):
def read(self, stream):
self.index = spack.tag.TagIndex.from_json(stream, self.repository)
def update(self, pkg_fullname):
self.index.update_package(pkg_fullname.split(".")[-1])
def update(self, pkgs_fullname: Set[str]):
self.index.update_packages({p.split(".")[-1] for p in pkgs_fullname})
def write(self, stream):
self.index.to_json(stream)
@ -498,15 +498,14 @@ def _create(self):
def read(self, stream):
self.index = spack.provider_index.ProviderIndex.from_json(stream, self.repository)
def update(self, pkg_fullname):
name = pkg_fullname.split(".")[-1]
def update(self, pkgs_fullname: Set[str]):
is_virtual = (
not self.repository.exists(name) or self.repository.get_pkg_class(name).virtual
lambda name: not self.repository.exists(name)
or self.repository.get_pkg_class(name).virtual
)
if is_virtual:
return
self.index.remove_provider(pkg_fullname)
self.index.update(pkg_fullname)
non_virtual_pkgs_fullname = {p for p in pkgs_fullname if not is_virtual(p.split(".")[-1])}
self.index.remove_providers(non_virtual_pkgs_fullname)
self.index.update_packages(non_virtual_pkgs_fullname)
def write(self, stream):
self.index.to_json(stream)
@ -531,8 +530,8 @@ def read(self, stream):
def write(self, stream):
self.index.to_json(stream)
def update(self, pkg_fullname):
self.index.update_package(pkg_fullname)
def update(self, pkgs_fullname: Set[str]):
self.index.update_packages(pkgs_fullname)
class RepoIndex:
@ -622,9 +621,7 @@ def _build_index(self, name: str, indexer: Indexer):
if new_index_mtime != index_mtime:
needs_update = self.checker.modified_since(new_index_mtime)
for pkg_name in needs_update:
indexer.update(f"{self.namespace}.{pkg_name}")
indexer.update({f"{self.namespace}.{pkg_name}" for pkg_name in needs_update})
indexer.write(new)
return indexer.index

View File

@ -5,6 +5,7 @@
import collections
import copy
from collections.abc import Mapping
from typing import Set
import spack.error
import spack.repo
@ -110,20 +111,17 @@ def merge(self, other):
spkgs, opkgs = self.tags[tag], other.tags[tag]
self.tags[tag] = sorted(list(set(spkgs + opkgs)))
def update_package(self, pkg_name):
"""Updates a package in the tag index.
Args:
pkg_name (str): name of the package to be removed from the index
"""
pkg_cls = self.repository.get_pkg_class(pkg_name)
def update_packages(self, pkg_names: Set[str]):
"""Updates a package in the tag index."""
# Remove the package from the list of packages, if present
for pkg_list in self._tag_dict.values():
if pkg_name in pkg_list:
pkg_list.remove(pkg_name)
if pkg_names.isdisjoint(pkg_list):
continue
pkg_list[:] = [pkg for pkg in pkg_list if pkg not in pkg_names]
# Add it again under the appropriate tags
for pkg_name in pkg_names:
pkg_cls = self.repository.get_pkg_class(pkg_name)
for tag in getattr(pkg_cls, "tags", []):
tag = tag.lower()
self._tag_dict[tag].append(pkg_cls.name)

View File

@ -154,7 +154,6 @@ def test_tag_no_tags(mock_packages):
def test_tag_update_package(mock_packages):
mock_index = mock_packages.tag_index
index = spack.tag.TagIndex(repository=mock_packages)
for name in spack.repo.all_package_names():
index.update_package(name)
index.update_packages(set(spack.repo.all_package_names()))
ensure_tags_results_equal(mock_index.tags, index.tags)