Add type-hints to RepoPath (#45068)

* Also, fix a bug with use_repositories + import spack.pkg
This commit is contained in:
Massimiliano Culpo 2024-07-08 11:48:39 +02:00 committed by GitHub
parent cef9c36183
commit 74398d74ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 121 additions and 89 deletions

View File

@ -941,9 +941,7 @@ def get_repository(args, name):
)
else:
if spec.namespace:
repo = spack.repo.PATH.get_repo(spec.namespace, None)
if not repo:
tty.die("Unknown namespace: '{0}'".format(spec.namespace))
repo = spack.repo.PATH.get_repo(spec.namespace)
else:
repo = spack.repo.PATH.first_repo()

View File

@ -12,7 +12,7 @@
import re
import sys
import warnings
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Type
import llnl.util.filesystem
import llnl.util.lang
@ -200,7 +200,7 @@ class Finder:
def default_path_hints(self) -> List[str]:
return []
def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]:
def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
"""Returns the list of patterns used to match candidate files.
Args:
@ -226,7 +226,7 @@ def prefix_from_path(self, *, path: str) -> str:
raise NotImplementedError("must be implemented by derived classes")
def detect_specs(
self, *, pkg: "spack.package_base.PackageBase", paths: List[str]
self, *, pkg: Type["spack.package_base.PackageBase"], paths: List[str]
) -> List[DetectedPackage]:
"""Given a list of files matching the search patterns, returns a list of detected specs.
@ -327,7 +327,7 @@ class ExecutablesFinder(Finder):
def default_path_hints(self) -> List[str]:
return spack.util.environment.get_path("PATH")
def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]:
def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
result = []
if hasattr(pkg, "executables") and hasattr(pkg, "platform_executables"):
result = pkg.platform_executables()
@ -356,7 +356,7 @@ class LibrariesFinder(Finder):
DYLD_LIBRARY_PATH, DYLD_FALLBACK_LIBRARY_PATH, and standard system library paths
"""
def search_patterns(self, *, pkg: "spack.package_base.PackageBase") -> List[str]:
def search_patterns(self, *, pkg: Type["spack.package_base.PackageBase"]) -> List[str]:
result = []
if hasattr(pkg, "libraries"):
result = pkg.libraries

View File

@ -9,7 +9,7 @@
import os.path
import pathlib
import sys
from typing import Any, Dict, Optional, Tuple, Type
from typing import Any, Dict, Optional, Tuple, Type, Union
import llnl.util.filesystem
from llnl.url import allowed_archive
@ -65,6 +65,9 @@ def apply_patch(
patch(*args)
PatchPackageType = Union["spack.package_base.PackageBase", Type["spack.package_base.PackageBase"]]
class Patch:
"""Base class for patches.
@ -77,7 +80,7 @@ class Patch:
def __init__(
self,
pkg: "spack.package_base.PackageBase",
pkg: PatchPackageType,
path_or_url: str,
level: int,
working_dir: str,
@ -159,7 +162,7 @@ class FilePatch(Patch):
def __init__(
self,
pkg: "spack.package_base.PackageBase",
pkg: PatchPackageType,
relative_path: str,
level: int,
working_dir: str,
@ -183,7 +186,7 @@ def __init__(
abs_path: Optional[str] = None
# At different times we call FilePatch on instances and classes
pkg_cls = pkg if inspect.isclass(pkg) else pkg.__class__
for cls in inspect.getmro(pkg_cls):
for cls in inspect.getmro(pkg_cls): # type: ignore
if not hasattr(cls, "module"):
# We've gone too far up the MRO
break
@ -242,7 +245,7 @@ class UrlPatch(Patch):
def __init__(
self,
pkg: "spack.package_base.PackageBase",
pkg: PatchPackageType,
url: str,
level: int = 1,
*,
@ -361,8 +364,9 @@ def from_dict(
"""
repository = repository or spack.repo.PATH
owner = dictionary.get("owner")
if "owner" not in dictionary:
raise ValueError("Invalid patch dictionary: %s" % dictionary)
if owner is None:
raise ValueError(f"Invalid patch dictionary: {dictionary}")
assert isinstance(owner, str)
pkg_cls = repository.get_pkg_class(owner)
if "url" in dictionary:

View File

@ -675,15 +675,22 @@ class RepoPath:
repository.
Args:
repos (list): list Repo objects or paths to put in this RepoPath
repos: list Repo objects or paths to put in this RepoPath
cache: file cache associated with this repository
overrides: dict mapping package name to class attribute overrides for that package
"""
def __init__(self, *repos, cache, overrides=None):
self.repos = []
def __init__(
self,
*repos: Union[str, "Repo"],
cache: spack.caches.FileCacheType,
overrides: Optional[Dict[str, Any]] = None,
) -> None:
self.repos: List[Repo] = []
self.by_namespace = nm.NamespaceTrie()
self._provider_index = None
self._patch_index = None
self._tag_index = None
self._provider_index: Optional[spack.provider_index.ProviderIndex] = None
self._patch_index: Optional[spack.patch.PatchCache] = None
self._tag_index: Optional[spack.tag.TagIndex] = None
# Add each repo to this path.
for repo in repos:
@ -694,13 +701,13 @@ def __init__(self, *repos, cache, overrides=None):
self.put_last(repo)
except RepoError as e:
tty.warn(
"Failed to initialize repository: '%s'." % repo,
f"Failed to initialize repository: '{repo}'.",
e.message,
"To remove the bad repository, run this command:",
" spack repo rm %s" % repo,
f" spack repo rm {repo}",
)
def put_first(self, repo):
def put_first(self, repo: "Repo") -> None:
"""Add repo first in the search path."""
if isinstance(repo, RepoPath):
for r in reversed(repo.repos):
@ -728,50 +735,34 @@ def remove(self, repo):
if repo in self.repos:
self.repos.remove(repo)
def get_repo(self, namespace, default=NOT_PROVIDED):
"""Get a repository by namespace.
Arguments:
namespace:
Look up this namespace in the RepoPath, and return it if found.
Optional Arguments:
default:
If default is provided, return it when the namespace
isn't found. If not, raise an UnknownNamespaceError.
"""
def get_repo(self, namespace: str) -> "Repo":
"""Get a repository by namespace."""
full_namespace = python_package_for_repo(namespace)
if full_namespace not in self.by_namespace:
if default == NOT_PROVIDED:
raise UnknownNamespaceError(namespace)
return default
raise UnknownNamespaceError(namespace)
return self.by_namespace[full_namespace]
def first_repo(self):
def first_repo(self) -> Optional["Repo"]:
"""Get the first repo in precedence order."""
return self.repos[0] if self.repos else None
@llnl.util.lang.memoized
def _all_package_names_set(self, include_virtuals):
def _all_package_names_set(self, include_virtuals) -> Set[str]:
return {name for repo in self.repos for name in repo.all_package_names(include_virtuals)}
@llnl.util.lang.memoized
def _all_package_names(self, include_virtuals):
def _all_package_names(self, include_virtuals: bool) -> List[str]:
"""Return all unique package names in all repositories."""
return sorted(self._all_package_names_set(include_virtuals), key=lambda n: n.lower())
def all_package_names(self, include_virtuals=False):
def all_package_names(self, include_virtuals: bool = False) -> List[str]:
return self._all_package_names(include_virtuals)
def package_path(self, name):
def package_path(self, name: str) -> str:
"""Get path to package.py file for this repo."""
return self.repo_for_pkg(name).package_path(name)
def all_package_paths(self):
def all_package_paths(self) -> Generator[str, None, None]:
for name in self.all_package_names():
yield self.package_path(name)
@ -787,53 +778,52 @@ def packages_with_tags(self, *tags: str, full: bool = False) -> Set[str]:
for pkg in repo.packages_with_tags(*tags)
}
def all_package_classes(self):
def all_package_classes(self) -> Generator[Type["spack.package_base.PackageBase"], None, None]:
for name in self.all_package_names():
yield self.get_pkg_class(name)
@property
def provider_index(self):
def provider_index(self) -> spack.provider_index.ProviderIndex:
"""Merged ProviderIndex from all Repos in the RepoPath."""
if self._provider_index is None:
self._provider_index = spack.provider_index.ProviderIndex(repository=self)
for repo in reversed(self.repos):
self._provider_index.merge(repo.provider_index)
return self._provider_index
@property
def tag_index(self):
def tag_index(self) -> spack.tag.TagIndex:
"""Merged TagIndex from all Repos in the RepoPath."""
if self._tag_index is None:
self._tag_index = spack.tag.TagIndex(repository=self)
for repo in reversed(self.repos):
self._tag_index.merge(repo.tag_index)
return self._tag_index
@property
def patch_index(self):
def patch_index(self) -> spack.patch.PatchCache:
"""Merged PatchIndex from all Repos in the RepoPath."""
if self._patch_index is None:
self._patch_index = spack.patch.PatchCache(repository=self)
for repo in reversed(self.repos):
self._patch_index.update(repo.patch_index)
return self._patch_index
@autospec
def providers_for(self, vpkg_spec):
def providers_for(self, virtual_spec: "spack.spec.Spec") -> List["spack.spec.Spec"]:
providers = [
spec
for spec in self.provider_index.providers_for(vpkg_spec)
for spec in self.provider_index.providers_for(virtual_spec)
if spec.name in self._all_package_names_set(include_virtuals=False)
]
if not providers:
raise UnknownPackageError(vpkg_spec.fullname)
raise UnknownPackageError(virtual_spec.fullname)
return providers
@autospec
def extensions_for(self, extendee_spec):
def extensions_for(
self, extendee_spec: "spack.spec.Spec"
) -> List["spack.package_base.PackageBase"]:
return [
pkg_cls(spack.spec.Spec(pkg_cls.name))
for pkg_cls in self.all_package_classes()
@ -844,7 +834,7 @@ def last_mtime(self):
"""Time a package file in this repo was last updated."""
return max(repo.last_mtime() for repo in self.repos)
def repo_for_pkg(self, spec):
def repo_for_pkg(self, spec: Union[str, "spack.spec.Spec"]) -> "Repo":
"""Given a spec, get the repository for its package."""
# We don't @_autospec this function b/c it's called very frequently
# and we want to avoid parsing str's into Specs unnecessarily.
@ -869,17 +859,20 @@ def repo_for_pkg(self, spec):
return repo
# If the package isn't in any repo, return the one with
# highest precedence. This is for commands like `spack edit`
# highest precedence. This is for commands like `spack edit`
# that can operate on packages that don't exist yet.
return self.first_repo()
selected = self.first_repo()
if selected is None:
raise UnknownPackageError(name)
return selected
def get(self, spec):
def get(self, spec: "spack.spec.Spec") -> "spack.package_base.PackageBase":
"""Returns the package associated with the supplied spec."""
msg = "RepoPath.get can only be called on concrete specs"
assert isinstance(spec, spack.spec.Spec) and spec.concrete, msg
return self.repo_for_pkg(spec).get(spec)
def get_pkg_class(self, pkg_name):
def get_pkg_class(self, pkg_name: str) -> Type["spack.package_base.PackageBase"]:
"""Find a class for the spec's package and return the class object."""
return self.repo_for_pkg(pkg_name).get_pkg_class(pkg_name)
@ -892,26 +885,26 @@ def dump_provenance(self, spec, path):
"""
return self.repo_for_pkg(spec).dump_provenance(spec, path)
def dirname_for_package_name(self, pkg_name):
def dirname_for_package_name(self, pkg_name: str) -> str:
return self.repo_for_pkg(pkg_name).dirname_for_package_name(pkg_name)
def filename_for_package_name(self, pkg_name):
def filename_for_package_name(self, pkg_name: str) -> str:
return self.repo_for_pkg(pkg_name).filename_for_package_name(pkg_name)
def exists(self, pkg_name):
def exists(self, pkg_name: str) -> bool:
"""Whether package with the give name exists in the path's repos.
Note that virtual packages do not "exist".
"""
return any(repo.exists(pkg_name) for repo in self.repos)
def _have_name(self, pkg_name):
def _have_name(self, pkg_name: str) -> bool:
have_name = pkg_name is not None
if have_name and not isinstance(pkg_name, str):
raise ValueError("is_virtual(): expected package name, got %s" % type(pkg_name))
raise ValueError(f"is_virtual(): expected package name, got {type(pkg_name)}")
return have_name
def is_virtual(self, pkg_name):
def is_virtual(self, pkg_name: str) -> bool:
"""Return True if the package with this name is virtual, False otherwise.
This function use the provider index. If calling from a code block that
@ -923,7 +916,7 @@ def is_virtual(self, pkg_name):
have_name = self._have_name(pkg_name)
return have_name and pkg_name in self.provider_index
def is_virtual_safe(self, pkg_name):
def is_virtual_safe(self, pkg_name: str) -> bool:
"""Return True if the package with this name is virtual, False otherwise.
This function doesn't use the provider index.
@ -1418,7 +1411,9 @@ def _path(configuration=None):
return create(configuration=configuration)
def create(configuration):
def create(
configuration: Union["spack.config.Configuration", llnl.util.lang.Singleton]
) -> RepoPath:
"""Create a RepoPath from a configuration object.
Args:
@ -1454,20 +1449,20 @@ def all_package_names(include_virtuals=False):
@contextlib.contextmanager
def use_repositories(*paths_and_repos, **kwargs):
def use_repositories(
*paths_and_repos: Union[str, Repo], override: bool = True
) -> Generator[RepoPath, None, None]:
"""Use the repositories passed as arguments within the context manager.
Args:
*paths_and_repos: paths to the repositories to be used, or
already constructed Repo objects
override (bool): if True use only the repositories passed as input,
override: if True use only the repositories passed as input,
if False add them to the top of the list of current repositories.
Returns:
Corresponding RepoPath object
"""
global PATH
# TODO (Python 2.7): remove this kwargs on deprecation of Python 2.7 support
override = kwargs.get("override", True)
paths = [getattr(x, "root", x) for x in paths_and_repos]
scope_name = "use-repo-{}".format(uuid.uuid4())
repos_key = "repos:" if override else "repos"
@ -1476,7 +1471,8 @@ def use_repositories(*paths_and_repos, **kwargs):
)
PATH, saved = create(configuration=spack.config.CONFIG), PATH
try:
yield PATH
with REPOS_FINDER.switch_repo(PATH): # type: ignore
yield PATH
finally:
spack.config.CONFIG.remove_scope(scope_name=scope_name)
PATH = saved
@ -1576,10 +1572,9 @@ class UnknownNamespaceError(UnknownEntityError):
"""Raised when we encounter an unknown namespace"""
def __init__(self, namespace, name=None):
msg, long_msg = "Unknown namespace: {}".format(namespace), None
msg, long_msg = f"Unknown namespace: {namespace}", None
if name == "yaml":
long_msg = "Did you mean to specify a filename with './{}.{}'?"
long_msg = long_msg.format(namespace, name)
long_msg = f"Did you mean to specify a filename with './{namespace}.{name}'?"
super().__init__(msg, long_msg)

View File

@ -2069,4 +2069,5 @@ def _c_compiler_always_exists():
@pytest.fixture(scope="session")
def mock_test_cache(tmp_path_factory):
cache_dir = tmp_path_factory.mktemp("cache")
print(cache_dir)
return spack.util.file_cache.FileCache(str(cache_dir))

View File

@ -3,6 +3,7 @@
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import os
import pathlib
import pytest
@ -205,6 +206,18 @@ def test_path_computation_with_names(method_name, mock_repo_path):
assert qualified == unqualified
def test_use_repositories_and_import():
"""Tests that use_repositories changes the import search too"""
import spack.paths
repo_dir = pathlib.Path(spack.paths.repos_path)
with spack.repo.use_repositories(str(repo_dir / "compiler_runtime.test")):
import spack.pkg.compiler_runtime.test.gcc_runtime
with spack.repo.use_repositories(str(repo_dir / "builtin.mock")):
import spack.pkg.builtin.mock.cmake
@pytest.mark.usefixtures("nullify_globals")
class TestRepo:
"""Test that the Repo class work correctly, and does not depend on globals,
@ -219,8 +232,9 @@ def test_creation(self, mock_test_cache):
@pytest.mark.parametrize(
"name,expected", [("mpi", True), ("mpich", False), ("mpileaks", False)]
)
def test_is_virtual(self, name, expected, mock_test_cache):
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
@pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
def test_is_virtual(self, repo_cls, name, expected, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
assert repo.is_virtual(name) is expected
assert repo.is_virtual_safe(name) is expected
@ -258,13 +272,15 @@ def test_providers(self, virtual_name, expected, mock_test_cache):
"extended,expected",
[("python", ["py-extension1", "python-venv"]), ("perl", ["perl-extension"])],
)
def test_extensions(self, extended, expected, mock_test_cache):
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
@pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
def test_extensions(self, repo_cls, extended, expected, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
provider_names = {x.name for x in repo.extensions_for(extended)}
assert provider_names.issuperset(expected)
def test_all_package_names(self, mock_test_cache):
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
@pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
def test_all_package_names(self, repo_cls, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
all_names = repo.all_package_names(include_virtuals=True)
real_names = repo.all_package_names(include_virtuals=False)
assert set(all_names).issuperset(real_names)
@ -272,10 +288,28 @@ def test_all_package_names(self, mock_test_cache):
assert repo.is_virtual(name)
assert repo.is_virtual_safe(name)
def test_packages_with_tags(self, mock_test_cache):
repo = spack.repo.Repo(spack.paths.mock_packages_path, cache=mock_test_cache)
@pytest.mark.parametrize("repo_cls", [spack.repo.Repo, spack.repo.RepoPath])
def test_packages_with_tags(self, repo_cls, mock_test_cache):
repo = repo_cls(spack.paths.mock_packages_path, cache=mock_test_cache)
r1 = repo.packages_with_tags("tag1")
r2 = repo.packages_with_tags("tag1", "tag2")
assert "mpich" in r1 and "mpich" in r2
assert "mpich2" in r1 and "mpich2" not in r2
assert set(r2).issubset(r1)
@pytest.mark.usefixtures("nullify_globals")
class TestRepoPath:
def test_creation_from_string(self, mock_test_cache):
repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache)
assert len(repo.repos) == 1
assert repo.repos[0]._finder is repo
assert repo.by_namespace["spack.pkg.builtin.mock"] is repo.repos[0]
def test_get_repo(self, mock_test_cache):
repo = spack.repo.RepoPath(spack.paths.mock_packages_path, cache=mock_test_cache)
# builtin.mock is there
assert repo.get_repo("builtin.mock") is repo.repos[0]
# foo is not there, raise
with pytest.raises(spack.repo.UnknownNamespaceError):
repo.get_repo("foo")