diff --git a/lib/spack/spack/patch.py b/lib/spack/spack/patch.py index a909bca1056..06c6cae48f2 100644 --- a/lib/spack/spack/patch.py +++ b/lib/spack/spack/patch.py @@ -6,7 +6,7 @@ import os import pathlib import sys -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Set, Tuple, Type, Union import llnl.util.filesystem from llnl.url import allowed_archive @@ -503,36 +503,38 @@ def patch_for_package(self, sha256: str, pkg: "spack.package_base.PackageBase") patch_dict["sha256"] = sha256 return from_dict(patch_dict, repository=self.repository) - def update_package(self, pkg_fullname: str) -> None: + def update_packages(self, pkgs_fullname: Set[str]) -> None: """Update the patch cache. Args: pkg_fullname: package to update. """ # remove this package from any patch entries that reference it. - empty = [] - for sha256, package_to_patch in self.index.items(): - remove = [] - for fullname, patch_dict in package_to_patch.items(): - if patch_dict["owner"] == pkg_fullname: - remove.append(fullname) + if self.index: + empty = [] + for sha256, package_to_patch in self.index.items(): + remove = [] + for fullname, patch_dict in package_to_patch.items(): + if patch_dict["owner"] in pkgs_fullname: + remove.append(fullname) - for fullname in remove: - package_to_patch.pop(fullname) + for fullname in remove: + package_to_patch.pop(fullname) - if not package_to_patch: - empty.append(sha256) + if not package_to_patch: + empty.append(sha256) - # remove any entries that are now empty - for sha256 in empty: - del self.index[sha256] + # remove any entries that are now empty + for sha256 in empty: + del self.index[sha256] # update the index with per-package patch indexes - pkg_cls = self.repository.get_pkg_class(pkg_fullname) - partial_index = self._index_patches(pkg_cls, self.repository) - for sha256, package_to_patch in partial_index.items(): - p2p = self.index.setdefault(sha256, {}) - p2p.update(package_to_patch) + for pkg_fullname in pkgs_fullname: + pkg_cls = self.repository.get_pkg_class(pkg_fullname) + partial_index = self._index_patches(pkg_cls, self.repository) + for sha256, package_to_patch in partial_index.items(): + p2p = self.index.setdefault(sha256, {}) + p2p.update(package_to_patch) def update(self, other: "PatchCache") -> None: """Update this cache with the contents of another. diff --git a/lib/spack/spack/provider_index.py b/lib/spack/spack/provider_index.py index 3a6431bba0b..5a1548d21b4 100644 --- a/lib/spack/spack/provider_index.py +++ b/lib/spack/spack/provider_index.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: (Apache-2.0 OR MIT) """Classes and functions to manage providers of virtual dependencies""" -from typing import Dict, List, Optional, Set +from typing import Dict, Iterable, List, Optional, Set, Union import spack.error import spack.spec @@ -99,66 +99,56 @@ def __init__( self.repository = repository self.restrict = restrict self.providers = {} + if specs: + self.update_packages(specs) - specs = specs or [] - for spec in specs: - if not isinstance(spec, spack.spec.Spec): - spec = spack.spec.Spec(spec) - - if self.repository.is_virtual_safe(spec.name): - continue - - self.update(spec) - - def update(self, spec): + def update_packages(self, specs: Iterable[Union[str, "spack.spec.Spec"]]): """Update the provider index with additional virtual specs. Args: spec: spec potentially providing additional virtual specs """ - if not isinstance(spec, spack.spec.Spec): - spec = spack.spec.Spec(spec) + for spec in specs: + if not isinstance(spec, spack.spec.Spec): + spec = spack.spec.Spec(spec) - if not spec.name: - # Empty specs do not have a package - return + if not spec.name or self.repository.is_virtual_safe(spec.name): + # Only non-virtual packages with name can provide virtual specs. + continue - msg = "cannot update an index passing the virtual spec '{}'".format(spec.name) - assert not self.repository.is_virtual_safe(spec.name), msg + pkg_provided = self.repository.get_pkg_class(spec.name).provided + for provider_spec_readonly, provided_specs in pkg_provided.items(): + for provided_spec in provided_specs: + # TODO: fix this comment. + # We want satisfaction other than flags + provider_spec = provider_spec_readonly.copy() + provider_spec.compiler_flags = spec.compiler_flags.copy() - pkg_provided = self.repository.get_pkg_class(spec.name).provided - for provider_spec_readonly, provided_specs in pkg_provided.items(): - for provided_spec in provided_specs: - # TODO: fix this comment. - # We want satisfaction other than flags - provider_spec = provider_spec_readonly.copy() - provider_spec.compiler_flags = spec.compiler_flags.copy() + if spec.intersects(provider_spec, deps=False): + provided_name = provided_spec.name - if spec.intersects(provider_spec, deps=False): - provided_name = provided_spec.name + provider_map = self.providers.setdefault(provided_name, {}) + if provided_spec not in provider_map: + provider_map[provided_spec] = set() - provider_map = self.providers.setdefault(provided_name, {}) - if provided_spec not in provider_map: - provider_map[provided_spec] = set() + if self.restrict: + provider_set = provider_map[provided_spec] - if self.restrict: - provider_set = provider_map[provided_spec] + # If this package existed in the index before, + # need to take the old versions out, as they're + # now more constrained. + old = {s for s in provider_set if s.name == spec.name} + provider_set.difference_update(old) - # If this package existed in the index before, - # need to take the old versions out, as they're - # now more constrained. - old = set([s for s in provider_set if s.name == spec.name]) - provider_set.difference_update(old) + # Now add the new version. + provider_set.add(spec) - # Now add the new version. - provider_set.add(spec) - - else: - # Before putting the spec in the map, constrain - # it so that it provides what was asked for. - constrained = spec.copy() - constrained.constrain(provider_spec) - provider_map[provided_spec].add(constrained) + else: + # Before putting the spec in the map, constrain + # it so that it provides what was asked for. + constrained = spec.copy() + constrained.constrain(provider_spec) + provider_map[provided_spec].add(constrained) def to_json(self, stream=None): """Dump a JSON representation of this object. @@ -193,14 +183,13 @@ def merge(self, other): spdict[provided_spec] = spdict[provided_spec].union(opdict[provided_spec]) - def remove_provider(self, pkg_name): + def remove_providers(self, pkgs_fullname: Set[str]): """Remove a provider from the ProviderIndex.""" empty_pkg_dict = [] for pkg, pkg_dict in self.providers.items(): empty_pset = [] for provided, pset in pkg_dict.items(): - same_name = set(p for p in pset if p.fullname == pkg_name) - pset.difference_update(same_name) + pset.difference_update(pkgs_fullname) if not pset: empty_pset.append(provided) diff --git a/lib/spack/spack/repo.py b/lib/spack/spack/repo.py index 62aaafaf638..7ef6364c253 100644 --- a/lib/spack/spack/repo.py +++ b/lib/spack/spack/repo.py @@ -465,7 +465,7 @@ def read(self, stream): """Read this index from a provided file object.""" @abc.abstractmethod - def update(self, pkg_fullname): + def update(self, pkgs_fullname: Set[str]): """Update the index in memory with information about a package.""" @abc.abstractmethod @@ -482,8 +482,8 @@ def _create(self): def read(self, stream): self.index = spack.tag.TagIndex.from_json(stream, self.repository) - def update(self, pkg_fullname): - self.index.update_package(pkg_fullname.split(".")[-1]) + def update(self, pkgs_fullname: Set[str]): + self.index.update_packages({p.split(".")[-1] for p in pkgs_fullname}) def write(self, stream): self.index.to_json(stream) @@ -498,15 +498,14 @@ def _create(self): def read(self, stream): self.index = spack.provider_index.ProviderIndex.from_json(stream, self.repository) - def update(self, pkg_fullname): - name = pkg_fullname.split(".")[-1] + def update(self, pkgs_fullname: Set[str]): is_virtual = ( - not self.repository.exists(name) or self.repository.get_pkg_class(name).virtual + lambda name: not self.repository.exists(name) + or self.repository.get_pkg_class(name).virtual ) - if is_virtual: - return - self.index.remove_provider(pkg_fullname) - self.index.update(pkg_fullname) + non_virtual_pkgs_fullname = {p for p in pkgs_fullname if not is_virtual(p.split(".")[-1])} + self.index.remove_providers(non_virtual_pkgs_fullname) + self.index.update_packages(non_virtual_pkgs_fullname) def write(self, stream): self.index.to_json(stream) @@ -531,8 +530,8 @@ def read(self, stream): def write(self, stream): self.index.to_json(stream) - def update(self, pkg_fullname): - self.index.update_package(pkg_fullname) + def update(self, pkgs_fullname: Set[str]): + self.index.update_packages(pkgs_fullname) class RepoIndex: @@ -622,9 +621,7 @@ def _build_index(self, name: str, indexer: Indexer): if new_index_mtime != index_mtime: needs_update = self.checker.modified_since(new_index_mtime) - for pkg_name in needs_update: - indexer.update(f"{self.namespace}.{pkg_name}") - + indexer.update({f"{self.namespace}.{pkg_name}" for pkg_name in needs_update}) indexer.write(new) return indexer.index diff --git a/lib/spack/spack/tag.py b/lib/spack/spack/tag.py index 8768ea39be0..2005b33e410 100644 --- a/lib/spack/spack/tag.py +++ b/lib/spack/spack/tag.py @@ -5,6 +5,7 @@ import collections import copy from collections.abc import Mapping +from typing import Set import spack.error import spack.repo @@ -110,23 +111,20 @@ def merge(self, other): spkgs, opkgs = self.tags[tag], other.tags[tag] self.tags[tag] = sorted(list(set(spkgs + opkgs))) - def update_package(self, pkg_name): - """Updates a package in the tag index. - - Args: - pkg_name (str): name of the package to be removed from the index - """ - pkg_cls = self.repository.get_pkg_class(pkg_name) - + def update_packages(self, pkg_names: Set[str]): + """Updates a package in the tag index.""" # Remove the package from the list of packages, if present for pkg_list in self._tag_dict.values(): - if pkg_name in pkg_list: - pkg_list.remove(pkg_name) + if pkg_names.isdisjoint(pkg_list): + continue + pkg_list[:] = [pkg for pkg in pkg_list if pkg not in pkg_names] # Add it again under the appropriate tags - for tag in getattr(pkg_cls, "tags", []): - tag = tag.lower() - self._tag_dict[tag].append(pkg_cls.name) + for pkg_name in pkg_names: + pkg_cls = self.repository.get_pkg_class(pkg_name) + for tag in getattr(pkg_cls, "tags", []): + tag = tag.lower() + self._tag_dict[tag].append(pkg_cls.name) class TagIndexError(spack.error.SpackError): diff --git a/lib/spack/spack/test/tag.py b/lib/spack/spack/test/tag.py index fe4d93dc2a4..5c497fce6cf 100644 --- a/lib/spack/spack/test/tag.py +++ b/lib/spack/spack/test/tag.py @@ -154,7 +154,6 @@ def test_tag_no_tags(mock_packages): def test_tag_update_package(mock_packages): mock_index = mock_packages.tag_index index = spack.tag.TagIndex(repository=mock_packages) - for name in spack.repo.all_package_names(): - index.update_package(name) + index.update_packages(set(spack.repo.all_package_names())) ensure_tags_results_equal(mock_index.tags, index.tags)