directives: use Type[PackageBase] instead of PackageBase

The first argument to each Spack directive is not a `PackageBase` instance but a
`PackageBase` class object, so fix the type annotations to reflect this.

Signed-off-by: Todd Gamblin <tgamblin@llnl.gov>
This commit is contained in:
Todd Gamblin 2024-11-29 22:55:18 -08:00
parent aa81d59958
commit 175a4bf101
4 changed files with 30 additions and 21 deletions

View File

@ -3,7 +3,7 @@
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""Data structures that represent Spack's dependency relationships."""
from typing import Dict, List
from typing import Dict, List, Type
import spack.deptypes as dt
import spack.spec
@ -38,7 +38,7 @@ class Dependency:
def __init__(
self,
pkg: "spack.package_base.PackageBase",
pkg: Type["spack.package_base.PackageBase"],
spec: "spack.spec.Spec",
depflag: dt.DepFlag = dt.DEFAULT,
):

View File

@ -21,6 +21,7 @@ class OpenMpi(Package):
* ``conflicts``
* ``depends_on``
* ``extends``
* ``license``
* ``patch``
* ``provides``
* ``resource``
@ -34,7 +35,7 @@ class OpenMpi(Package):
import collections.abc
import os.path
import re
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Type, Union
import llnl.util.tty.color
@ -81,7 +82,7 @@ class OpenMpi(Package):
SpecType = str
DepType = Union[Tuple[str, ...], str]
WhenType = Optional[Union[spack.spec.Spec, str, bool]]
Patcher = Callable[[Union[spack.package_base.PackageBase, Dependency]], None]
Patcher = Callable[[Union[Type[spack.package_base.PackageBase], Dependency]], None]
PatchesType = Union[Patcher, str, List[Union[Patcher, str]]]
@ -218,7 +219,7 @@ def version(
return lambda pkg: _execute_version(pkg, ver, **kwargs)
def _execute_version(pkg, ver, **kwargs):
def _execute_version(pkg: Type[spack.package_base.PackageBase], ver: Union[str, int], **kwargs):
if (
(any(s in kwargs for s in spack.util.crypto.hashes) or "checksum" in kwargs)
and hasattr(pkg, "has_code")
@ -249,7 +250,7 @@ def _execute_version(pkg, ver, **kwargs):
def _depends_on(
pkg: spack.package_base.PackageBase,
pkg: Type[spack.package_base.PackageBase],
spec: spack.spec.Spec,
*,
when: WhenType = None,
@ -329,7 +330,7 @@ def conflicts(conflict_spec: SpecType, when: WhenType = None, msg: Optional[str]
msg (str): optional user defined message
"""
def _execute_conflicts(pkg: spack.package_base.PackageBase):
def _execute_conflicts(pkg: Type[spack.package_base.PackageBase]):
# If when is not specified the conflict always holds
when_spec = _make_when_spec(when)
if not when_spec:
@ -370,14 +371,16 @@ def depends_on(
assert type == "build", "languages must be of 'build' type"
return _language(lang_spec_str=spec, when=when)
def _execute_depends_on(pkg: spack.package_base.PackageBase):
def _execute_depends_on(pkg: Type[spack.package_base.PackageBase]):
_depends_on(pkg, dep_spec, when=when, type=type, patches=patches)
return _execute_depends_on
@directive("disable_redistribute")
def redistribute(source=None, binary=None, when: WhenType = None):
def redistribute(
source: Optional[bool] = None, binary: Optional[bool] = None, when: WhenType = None
):
"""Can be used inside a Package definition to declare that
the package source and/or compiled binaries should not be
redistributed.
@ -392,7 +395,10 @@ def redistribute(source=None, binary=None, when: WhenType = None):
def _execute_redistribute(
pkg: spack.package_base.PackageBase, source=None, binary=None, when: WhenType = None
pkg: Type[spack.package_base.PackageBase],
source: Optional[bool],
binary: Optional[bool],
when: WhenType,
):
if source is None and binary is None:
return
@ -468,7 +474,7 @@ def provides(*specs: SpecType, when: WhenType = None):
when: condition when this provides clause needs to be considered
"""
def _execute_provides(pkg: spack.package_base.PackageBase):
def _execute_provides(pkg: Type[spack.package_base.PackageBase]):
import spack.parser # Avoid circular dependency
when_spec = _make_when_spec(when)
@ -516,7 +522,7 @@ def can_splice(
variants will be skipped by '*'.
"""
def _execute_can_splice(pkg: spack.package_base.PackageBase):
def _execute_can_splice(pkg: Type[spack.package_base.PackageBase]):
when_spec = _make_when_spec(when)
if isinstance(match_variants, str) and match_variants != "*":
raise ValueError(
@ -557,10 +563,10 @@ def patch(
compressed URL patches)
"""
def _execute_patch(pkg_or_dep: Union[spack.package_base.PackageBase, Dependency]):
pkg = pkg_or_dep
if isinstance(pkg, Dependency):
pkg = pkg.pkg
def _execute_patch(
pkg_or_dep: Union[Type[spack.package_base.PackageBase], Dependency]
) -> None:
pkg = pkg_or_dep.pkg if isinstance(pkg_or_dep, Dependency) else pkg_or_dep
if hasattr(pkg, "has_code") and not pkg.has_code:
raise UnsupportedPackageDirective(
@ -817,7 +823,9 @@ def _execute_maintainer(pkg):
return _execute_maintainer
def _execute_license(pkg, license_identifier: str, when):
def _execute_license(
pkg: Type[spack.package_base.PackageBase], license_identifier: str, when: WhenType
):
# If when is not specified the license always holds
when_spec = _make_when_spec(when)
if not when_spec:
@ -881,7 +889,7 @@ def requires(*requirement_specs: str, policy="one_of", when=None, msg=None):
msg: optional user defined message
"""
def _execute_requires(pkg: spack.package_base.PackageBase):
def _execute_requires(pkg: Type[spack.package_base.PackageBase]):
if policy not in ("one_of", "any_of"):
err_msg = (
f"the 'policy' argument of the 'requires' directive in {pkg.name} is set "
@ -906,7 +914,7 @@ def _execute_requires(pkg: spack.package_base.PackageBase):
def _language(lang_spec_str: str, *, when: Optional[Union[str, bool]] = None):
"""Temporary implementation of language virtuals, until compilers are proper dependencies."""
def _execute_languages(pkg: spack.package_base.PackageBase):
def _execute_languages(pkg: Type[spack.package_base.PackageBase]):
when_spec = _make_when_spec(when)
if not when_spec:
return

View File

@ -595,6 +595,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
patches: Dict[spack.spec.Spec, List[spack.patch.Patch]]
variants: Dict[spack.spec.Spec, Dict[str, spack.variant.Variant]]
languages: Dict[spack.spec.Spec, Set[str]]
licenses: Dict[spack.spec.Spec, str]
splice_specs: Dict[spack.spec.Spec, Tuple[spack.spec.Spec, Union[None, str, List[str]]]]
#: Store whether a given Spec source/binary should not be redistributed.

View File

@ -219,10 +219,10 @@ class MockPackage:
disable_redistribute = {}
cls = MockPackage
spack.directives._execute_redistribute(cls, source=False, when="@1.0")
spack.directives._execute_redistribute(cls, source=False, binary=None, when="@1.0")
spec_key = spack.directives._make_when_spec("@1.0")
assert not cls.disable_redistribute[spec_key].binary
assert cls.disable_redistribute[spec_key].source
spack.directives._execute_redistribute(cls, binary=False, when="@1.0")
spack.directives._execute_redistribute(cls, source=None, binary=False, when="@1.0")
assert cls.disable_redistribute[spec_key].binary
assert cls.disable_redistribute[spec_key].source