Add type-hints to RepoPath (#45068)

* Also, fix a bug with use_repositories + import spack.pkg
This commit is contained in:
Massimiliano Culpo 2024-07-08 11:48:39 +02:00 committed by GitHub
parent cef9c36183
commit 74398d74ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 121 additions and 89 deletions

View File

@ -941,9 +941,7 @@ def get_repository(args, name):
) )
else: else:
if spec.namespace: if spec.namespace:
repo = spack.repo.PATH.get_repo(spec.namespace, None) repo = spack.repo.PATH.get_repo(spec.namespace)
if not repo:
tty.die("Unknown namespace: '{0}'".format(spec.namespace))
else: else:
repo = spack.repo.PATH.first_repo() repo = spack.repo.PATH.first_repo()

View File

@ -12,7 +12,7 @@
import re import re
import sys import sys
import warnings import warnings
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple, Type
import llnl.util.filesystem import llnl.util.filesystem
import llnl.util.lang import llnl.util.lang
@ -200,7 +200,7 @@ class Finder:
def default_path_hints(self) -> List[str]: def default_path_hints(self) -> List[str]:
return [] return []
def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]: def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
"""Returns the list of patterns used to match candidate files. """Returns the list of patterns used to match candidate files.
Args: Args:
@ -226,7 +226,7 @@ def prefix_from_path(self, *, path: str) -> str:
raise NotImplementedError("must be implemented by derived classes") raise NotImplementedError("must be implemented by derived classes")
def detect_specs( def detect_specs(
self, *, pkg: "spack.package_base.PackageBase", paths: List[str] self, *, pkg: Type["spack.package_base.PackageBase"], paths: List[str]
) -> List[DetectedPackage]: ) -> List[DetectedPackage]:
"""Given a list of files matching the search patterns, returns a list of detected specs. """Given a list of files matching the search patterns, returns a list of detected specs.
@ -327,7 +327,7 @@ class ExecutablesFinder(Finder):
def default_path_hints(self) -> List[str]: def default_path_hints(self) -> List[str]:
return spack.util.environment.get_path("PATH") return spack.util.environment.get_path("PATH")
def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]: def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
result = [] result = []
if hasattr(pkg, "executables") and hasattr(pkg, "platform_executables"): if hasattr(pkg, "executables") and hasattr(pkg, "platform_executables"):
result = pkg.platform_executables() result = pkg.platform_executables()
@ -356,7 +356,7 @@ class LibrariesFinder(Finder):
DYLD_LIBRARY_PATH, DYLD_FALLBACK_LIBRARY_PATH, and standard system library paths DYLD_LIBRARY_PATH, DYLD_FALLBACK_LIBRARY_PATH, and standard system library paths
""" """
def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]: def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
result = [] result = []
if hasattr(pkg, "libraries"): if hasattr(pkg, "libraries"):
result = pkg.libraries result = pkg.libraries

View File

@ -9,7 +9,7 @@
import os.path import os.path
import pathlib import pathlib
import sys import sys
from typing import Any, Dict, Optional, Tuple, Type from typing import Any, Dict, Optional, Tuple, Type, Union
import llnl.util.filesystem import llnl.util.filesystem
from llnl.url import allowed_archive from llnl.url import allowed_archive
@ -65,6 +65,9 @@ def apply_patch(
patch(*args) patch(*args)
PatchPackageType = Union["spack.package_base.PackageBase", Type["spack.package_base.PackageBase"]]
class Patch: class Patch:
"""Base class for patches. """Base class for patches.
@ -77,7 +80,7 @@ class Patch:
def __init__( def __init__(
self, self,
pkg: "spack.package_base.PackageBase", pkg: PatchPackageType,
path_or_url: str, path_or_url: str,
level: int, level: int,
working_dir: str, working_dir: str,
@ -159,7 +162,7 @@ class FilePatch(Patch):
def __init__( def __init__(
self, self,
pkg: "spack.package_base.PackageBase", pkg: PatchPackageType,
relative_path: str, relative_path: str,
level: int, level: int,
working_dir: str, working_dir: str,
@ -183,7 +186,7 @@ def __init__(
abs_path: Optional[str] = None abs_path: Optional[str] = None
# At different times we call FilePatch on instances and classes # At different times we call FilePatch on instances and classes
pkg_cls = pkg if inspect.isclass(pkg) else pkg.__class__ pkg_cls = pkg if inspect.isclass(pkg) else pkg.__class__
for cls in inspect.getmro(pkg_cls): for cls in inspect.getmro(pkg_cls): # type: ignore
if not hasattr(cls, "module"): if not hasattr(cls, "module"):
# We've gone too far up the MRO # We've gone too far up the MRO
break break
@ -242,7 +245,7 @@ class UrlPatch(Patch):
def __init__( def __init__(
self, self,
pkg: "spack.package_base.PackageBase", pkg: PatchPackageType,
url: str, url: str,
level: int = 1, level: int = 1,
*, *,
@ -361,8 +364,9 @@ def from_dict(
""" """
repository = repository or spack.repo.PATH repository = repository or spack.repo.PATH
owner = dictionary.get("owner") owner = dictionary.get("owner")
if "owner" not in dictionary: if owner is None:
raise ValueError("Invalid patch dictionary: %s" % dictionary) raise ValueError(f"Invalid patch dictionary: {dictionary}")
assert isinstance(owner, str)
pkg_cls = repository.get_pkg_class(owner) pkg_cls = repository.get_pkg_class(owner)
if "url" in dictionary: if "url" in dictionary:

View File

@ -675,15 +675,22 @@ class RepoPath:
repository. repository.
Args: Args:
repos (list): list Repo objects or paths to put in this RepoPath repos: list Repo objects or paths to put in this RepoPath
cache: file cache associated with this repository
overrides: dict mapping package name to class attribute overrides for that package
""" """
def __init__(self, *repos, cache, overrides=None): def __init__(
self.repos = [] self,
*repos: Union[str, "Repo"],
cache: spack.caches.FileCacheType,
overrides: Optional[Dict[str, Any]] = None,
) -> None:
self.repos: List[Repo] = []
self.by_namespace = nm.NamespaceTrie() self.by_namespace = nm.NamespaceTrie()
self._provider_index = None self._provider_index: Optional[spack.provider_index.ProviderIndex] = None
self._patch_index = None self._patch_index: Optional[spack.patch.PatchCache] = None
self._tag_index = None self._tag_index: Optional[spack.tag.TagIndex] = None
# Add each repo to this path. # Add each repo to this path.
for repo in repos: for repo in repos:
@ -694,13 +701,13 @@ def __init__(self, *repos, cache, overrides=None):
self.put_last(repo) self.put_last(repo)
except RepoError as e: except RepoError as e:
tty.warn( tty.warn(
"Failed to initialize repository: '%s'." % repo, f"Failed to initialize repository: '{repo}'.",
e.message, e.message,
"To remove the bad repository, run this command:", "To remove the bad repository, run this command:",
" spack repo rm %s" % repo, f" spack repo rm {repo}",
) )
def put_first(self, repo): def put_first(self, repo: "Repo") -> None:
"""Add repo first in the search path.""" """Add repo first in the search path."""
if isinstance(repo, RepoPath): if isinstance(repo, RepoPath):
for r in reversed(repo.repos): for r in reversed(repo.repos):
@ -728,50 +735,34 @@ def remove(self, repo):
if repo in self.repos: if repo in self.repos:
self.repos.remove(repo) self.repos.remove(repo)
def get_repo(self, namespace, default=NOT_PROVIDED): def get_repo(self, namespace: str) -> "Repo":
"""Get a repository by namespace. """Get a repository by namespace."""
Arguments:
namespace:
Look up this namespace in the RepoPath, and return it if found.
Optional Arguments:
default:
If default is provided, return it when the namespace
isn't found. If not, raise an UnknownNamespaceError.
"""
full_namespace = python_package_for_repo(namespace) full_namespace = python_package_for_repo(namespace)
if full_namespace not in self.by_namespace: if full_namespace not in self.by_namespace:
if default == NOT_PROVIDED: raise UnknownNamespaceError(namespace)
raise UnknownNamespaceError(namespace)
return default
return self.by_namespace[full_namespace] return self.by_namespace[full_namespace]
def first_repo(self): def first_repo(self) -> Optional["Repo"]:
"""Get the first repo in precedence order.""" """Get the first repo in precedence order."""
return self.repos[0] if self.repos else None return self.repos[0] if self.repos else None
@llnl.util.lang.memoized @llnl.util.lang.memoized
def _all_package_names_set(self, include_virtuals): def _all_package_names_set(self, include_virtuals) -> Set[str]:
return {name for repo in self.repos for name in repo.all_package_names(include_virtuals)} return {name for repo in self.repos for name in repo.all_package_names(include_virtuals)}
@llnl.util.lang.memoized @llnl.util.lang.memoized
def _all_package_names(self, include_virtuals): def _all_package_names(self, include_virtuals: bool) -> List[str]:
"""Return all unique package names in all repositories.""" """Return all unique package names in all repositories."""
return sorted(self._all_package_names_set(include_virtuals), key=lambda n: n.lower()) return sorted(self._all_package_names_set(include_virtuals), key=lambda n: n.lower())
def all_package_names(self, include_virtuals=False): def all_package_names(self, include_virtuals: bool = False) -> List[str]:
return self._all_package_names(include_virtuals) return self._all_package_names(include_virtuals)
def package_path(self, name): def package_path(self, name: str) -> str:
"""Get path to package.py file for this repo.""" """Get path to package.py file for this repo."""
return self.repo_for_pkg(name).package_path(name) return self.repo_for_pkg(name).package_path(name)
def all_package_paths(self): def all_package_paths(self) -> Generator[str, None, None]:
for name in self.all_package_names(): for name in self.all_package_names():
yield self.package_path(name) yield self.package_path(name)
@ -787,53 +778,52 @@ def packages_with_tags(self, *tags: str, full: bool = False) -> Set[str]:
for pkg in repo.packages_with_tags(*tags) for pkg in repo.packages_with_tags(*tags)
} }
def all_package_classes(self): def all_package_classes(self) -> Generator[Type["spack.package_base.PackageBase"], None, None]:
for name in self.all_package_names(): for name in self.all_package_names():
yield self.get_pkg_class(name) yield self.get_pkg_class(name)
@property @property
def provider_index(self): def provider_index(self) -> spack.provider_index.ProviderIndex:
"""Merged ProviderIndex from all Repos in the RepoPath.""" """Merged ProviderIndex from all Repos in the RepoPath."""
if self._provider_index is None: if self._provider_index is None:
self._provider_index = spack.provider_index.ProviderIndex(repository=self) self._provider_index = spack.provider_index.ProviderIndex(repository=self)
for repo in reversed(self.repos): for repo in reversed(self.repos):
self._provider_index.merge(repo.provider_index) self._provider_index.merge(repo.provider_index)
return self._provider_index return self._provider_index
@property @property
def tag_index(self): def tag_index(self) -> spack.tag.TagIndex:
"""Merged TagIndex from all Repos in the RepoPath.""" """Merged TagIndex from all Repos in the RepoPath."""
if self._tag_index is None: if self._tag_index is None:
self._tag_index = spack.tag.TagIndex(repository=self) self._tag_index = spack.tag.TagIndex(repository=self)
for repo in reversed(self.repos): for repo in reversed(self.repos):
self._tag_index.merge(repo.tag_index) self._tag_index.merge(repo.tag_index)
return self._tag_index return self._tag_index
@property @property
def patch_index(self): def patch_index(self) -> spack.patch.PatchCache:
"""Merged PatchIndex from all Repos in the RepoPath.""" """Merged PatchIndex from all Repos in the RepoPath."""
if self._patch_index is None: if self._patch_index is None:
self._patch_index = spack.patch.PatchCache(repository=self) self._patch_index = spack.patch.PatchCache(repository=self)
for repo in reversed(self.repos): for repo in reversed(self.repos):
self._patch_index.update(repo.patch_index) self._patch_index.update(repo.patch_index)
return self._patch_index return self._patch_index
@autospec @autospec
def providers_for(self, vpkg_spec): def providers_for(self, virtual_spec: "spack.spec.Spec") -> List["spack.spec.Spec"]:
providers = [ providers = [
spec spec
for spec in self.provider_index.providers_for(vpkg_spec) for spec in self.provider_index.providers_for(virtual_spec)
if spec.name in self._all_package_names_set(include_virtuals=False) if spec.name in self._all_package_names_set(include_virtuals=False)
] ]
if not providers: if not providers:
raise UnknownPackageError(vpkg_spec.fullname) raise UnknownPackageError(virtual_spec.fullname)
return providers return providers
@autospec @autospec
def extensions_for(self, extendee_spec): def extensions_for(
self, extendee_spec: "spack.spec.Spec"
) -> List["spack.package_base.PackageBase"]:
return [ return [
pkg_cls(spack.spec.Spec(pkg_cls.name)) pkg_cls(spack.spec.Spec(pkg_cls.name))
for pkg_cls in self.all_package_classes() for pkg_cls in self.all_package_classes()
@ -844,7 +834,7 @@ def last_mtime(self):
"""Time a package file in this repo was last updated.""" """Time a package file in this repo was last updated."""
return max(repo.last_mtime() for repo in self.repos) return max(repo.last_mtime() for repo in self.repos)
def repo_for_pkg(self, spec): def repo_for_pkg(self, spec: Union[str, "spack.spec.Spec"]) -> "Repo":
"""Given a spec, get the repository for its package.""" """Given a spec, get the repository for its package."""
# We don't @_autospec this function b/c it's called very frequently # We don't @_autospec this function b/c it's called very frequently
# and we want to avoid parsing str's into Specs unnecessarily. # and we want to avoid parsing str's into Specs unnecessarily.
@ -869,17 +859,20 @@ def repo_for_pkg(self, spec):
return repo return repo
# If the package isn't in any repo, return the one with # If the package isn't in any repo, return the one with
# highest precedence. This is for commands like `spack edit` # highest precedence. This is for commands like `spack edit`
# that can operate on packages that don't exist yet. # that can operate on packages that don't exist yet.
return self.first_repo() selected = self.first_repo()
if selected is None:
raise UnknownPackageError(name)
return selected
def get(self, spec): def get(self, spec: "spack.spec.Spec") -> "spack.package_base.PackageBase":
"""Returns the package associated with the supplied spec.""" """Returns the package associated with the supplied spec."""
msg = "RepoPath.get can only be called on concrete specs" msg = "RepoPath.get can only be called on concrete specs"
assert isinstance(spec, spack.spec.Spec) and spec.concrete, msg assert isinstance(spec, spack.spec.Spec) and spec.concrete, msg
return self.repo_for_pkg(spec).get(spec) return self.repo_for_pkg(spec).get(spec)
def get_pkg_class(self, pkg_name): def get_pkg_class(self, pkg_name: str) -> Type["spack.package_base.PackageBase"]:
"""Find a class for the spec's package and return the class object.""" """Find a class for the spec's package and return the class object."""
return self.repo_for_pkg(pkg_name).get_pkg_class(pkg_name) return self.repo_for_pkg(pkg_name).get_pkg_class(pkg_name)
@ -892,26 +885,26 @@ def dump_provenance(self, spec, path):
""" """
return self.repo_for_pkg(spec).dump_provenance(spec, path) return self.repo_for_pkg(spec).dump_provenance(spec, path)
def dirname_for_package_name(self, pkg_name): def dirname_for_package_name(self, pkg_name: str) -> str:
return self.repo_for_pkg(pkg_name).dirname_for_package_name(pkg_name) return self.repo_for_pkg(pkg_name).dirname_for_package_name(pkg_name)
def filename_for_package_name(self, pkg_name): def filename_for_package_name(self, pkg_name: str) -> str:
return self.repo_for_pkg(pkg_name).filename_for_package_name(pkg_name) return self.repo_for_pkg(pkg_name).filename_for_package_name(pkg_name)
def exists(self, pkg_name): def exists(self, pkg_name: str) -> bool:
"""Whether package with the give name exists in the path's repos. """Whether package with the give name exists in the path's repos.
Note that virtual packages do not "exist". Note that virtual packages do not "exist".
""" """
return any(repo.exists(pkg_name) for repo in self.repos) return any(repo.exists(pkg_name) for repo in self.repos)
def _have_name(self, pkg_name): def _have_name(self, pkg_name: str) -> bool:
have_name = pkg_name is not None have_name = pkg_name is not None
if have_name and not isinstance(pkg_name, str): if have_name and not isinstance(pkg_name, str):
raise ValueError("is_virtual(): expected package name, got %s" % type(pkg_name)) raise ValueError(f"is_virtual(): expected package name, got {type(pkg_name)}")
return have_name return have_name
def is_virtual(self, pkg_name): def is_virtual(self, pkg_name: str) -> bool:
"""Return True if the package with this name is virtual, False otherwise. """Return True if the package with this name is virtual, False otherwise.
This function use the provider index. If calling from a code block that This function use the provider index. If calling from a code block that
@ -923,7 +916,7 @@ def is_virtual(self, pkg_name):
have_name = self._have_name(pkg_name) have_name = self._have_name(pkg_name)
return have_name and pkg_name in self.provider_index return have_name and pkg_name in self.provider_index
def is_virtual_safe(self, pkg_name): def is_virtual_safe(self, pkg_name: str) -> bool:
"""Return True if the package with this name is virtual, False otherwise. """Return True if the package with this name is virtual, False otherwise.
This function doesn't use the provider index. This function doesn't use the provider index.
@ -1418,7 +1411,9 @@ def _path(configuration=None):
return create(configuration=configuration) return create(configuration=configuration)
def create(configuration): def create(
configuration: Union["spack.config.Configuration", llnl.util.lang.Singleton]
) -> RepoPath:
"""Create a RepoPath from a configuration object. """Create a RepoPath from a configuration object.
Args: Args:
@ -1454,20 +1449,20 @@ def all_package_names(include_virtuals=False):
@contextlib.contextmanager @contextlib.contextmanager
def use_repositories(*paths_and_repos, **kwargs): def use_repositories(
*paths_and_repos: Union[str, Repo], override: bool = True
) -> Generator[RepoPath, None, None]:
"""Use the repositories passed as arguments within the context manager. """Use the repositories passed as arguments within the context manager.
Args: Args:
*paths_and_repos: paths to the repositories to be used, or *paths_and_repos: paths to the repositories to be used, or
already constructed Repo objects already constructed Repo objects
override (bool): if True use only the repositories passed as input, override: if True use only the repositories passed as input,
if False add them to the top of the list of current repositories. if False add them to the top of the list of current repositories.
Returns: Returns:
Corresponding RepoPath object Corresponding RepoPath object
""" """
global PATH global PATH
# TODO (Python 2.7): remove this kwargs on deprecation of Python 2.7 support
override = kwargs.get("override", True)
paths = [getattr(x, "root", x) for x in paths_and_repos] paths = [getattr(x, "root", x) for x in paths_and_repos]
scope_name = "use-repo-{}".format(uuid.uuid4()) scope_name = "use-repo-{}".format(uuid.uuid4())
repos_key = "repos:" if override else "repos" repos_key = "repos:" if override else "repos"
@ -1476,7 +1471,8 @@ def use_repositories(*paths_and_repos, **kwargs):
) )
PATH, saved = create(configuration=spack.config.CONFIG), PATH PATH, saved = create(configuration=spack.config.CONFIG), PATH
try: try:
yield PATH with REPOS_FINDER.switch_repo(PATH): # type: ignore
yield PATH
finally: finally:
spack.config.CONFIG.remove_scope(scope_name=scope_name) spack.config.CONFIG.remove_scope(scope_name=scope_name)
PATH = saved PATH = saved
@ -1576,10 +1572,9 @@ class UnknownNamespaceError(UnknownEntityError):
"""Raised when we encounter an unknown namespace""" """Raised when we encounter an unknown namespace"""
def __init__(self, namespace, name=None): def __init__(self, namespace, name=None):
msg, long_msg = "Unknown namespace: {}".format(namespace), None msg, long_msg = f"Unknown namespace: {namespace}", None
if name == "yaml": if name == "yaml":
long_msg = "Did you mean to specify a filename with './{}.{}'?" long_msg = f"Did you mean to specify a filename with './{namespace}.{name}'?"
long_msg = long_msg.format(namespace, name)
super().__init__(msg, long_msg) super().__init__(msg, long_msg)

View File

@ -2069,4 +2069,5 @@ def _c_compiler_always_exists():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def mock_test_cache(tmp_path_factory): def mock_test_cache(tmp_path_factory):
cache_dir = tmp_path_factory.mktemp("cache") cache_dir = tmp_path_factory.mktemp("cache")
print(cache_dir)
return spack.util.file_cache.FileCache(str(cache_dir)) return spack.util.file_cache.FileCache(str(cache_dir))

View File

@ -3,6 +3,7 @@
# #
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
import os import os
import pathlib
import pytest import pytest
@ -205,6 +206,18 @@ def test_path_computation_with_names(method_name, mock_repo_path):
assert qualified == unqualified assert qualified == unqualified
def test_use_repositories_and_import():
"""Tests that use_repositories changes the import search too"""
import spack.paths
repo_dir = pathlib.Path(spack.paths.repos_path)
with spack.repo.use_repositories(str(repo_dir / "compiler_runtime.test")):
import spack.pkg.compiler_runtime.test.gcc_runtime
with spack.repo.use_repositories(str(repo_dir / "builtin.mock")):
import spack.pkg.builtin.mock.cmake
@pytest.mark.usefixtures("nullify_globals") @pytest.mark.usefixtures("nullify_globals")
class TestRepo: class TestRepo:
"""Test that the Repo class work correctly, and does not depend on globals, """Test that the Repo class work correctly, and does not depend on globals,
@ -219,8 +232,9 @@ def test_creation(self, mock_test_cache):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"name,expected", [("mpi", True), ("mpich", False), ("mpileaks", False)] "name,expected", [("mpi", True), ("mpich", False), ("mpileaks", False)]
) )
def test_is_virtual(self, name, expected, mock_test_cache): @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) def test_is_virtual(self, repo_cls, name, expected, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
assert repo.is_virtual(name) is expected assert repo.is_virtual(name) is expected
assert repo.is_virtual_safe(name) is expected assert repo.is_virtual_safe(name) is expected
@ -258,13 +272,15 @@ def test_providers(self, virtual_name, expected, mock_test_cache):
"extended,expected", "extended,expected",
[("python", ["py-extension1", "python-venv"]), ("perl", ["perl-extension"])], [("python", ["py-extension1", "python-venv"]), ("perl", ["perl-extension"])],
) )
def test_extensions(self, extended, expected, mock_test_cache): @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) def test_extensions(self, repo_cls, extended, expected, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
provider_names = {x.name for x in repo.extensions_for(extended)} provider_names = {x.name for x in repo.extensions_for(extended)}
assert provider_names.issuperset(expected) assert provider_names.issuperset(expected)
def test_all_package_names(self, mock_test_cache): @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) def test_all_package_names(self, repo_cls, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
all_names = repo.all_package_names(include_virtuals=True) all_names = repo.all_package_names(include_virtuals=True)
real_names = repo.all_package_names(include_virtuals=False) real_names = repo.all_package_names(include_virtuals=False)
assert set(all_names).issuperset(real_names) assert set(all_names).issuperset(real_names)
@ -272,10 +288,28 @@ def test_all_package_names(self, mock_test_cache):
assert repo.is_virtual(name) assert repo.is_virtual(name)
assert repo.is_virtual_safe(name) assert repo.is_virtual_safe(name)
def test_packages_with_tags(self, mock_test_cache): @pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache) def test_packages_with_tags(self, repo_cls, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
r1 = repo.packages_with_tags("tag1") r1 = repo.packages_with_tags("tag1")
r2 = repo.packages_with_tags("tag1", "tag2") r2 = repo.packages_with_tags("tag1", "tag2")
assert "mpich" in r1 and "mpich" in r2 assert "mpich" in r1 and "mpich" in r2
assert "mpich2" in r1 and "mpich2" not in r2 assert "mpich2" in r1 and "mpich2" not in r2
assert set(r2).issubset(r1) assert set(r2).issubset(r1)
@pytest.mark.usefixtures("nullify_globals")
class TestRepoPath:
def test_creation_from_string(self, mock_test_cache):
repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache)
assert len(repo.repos) == 1
assert repo.repos[0]._finder is repo
assert repo.by_namespace["spack.pkg.builtin.mock"] is repo.repos[0]
def test_get_repo(self, mock_test_cache):
repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache)
# builtin.mock is there
assert repo.get_repo("builtin.mock") is repo.repos[0]
# foo is not there, raise
with pytest.raises(spack.repo.UnknownNamespaceError):
repo.get_repo("foo")