directives: use Type[PackageBase] instead of PackageBase

The first argument to each Spack directive is not a `PackageBase` instance but a
`PackageBase` class object, so fix the type annotations to reflect this.

Signed-off-by: Todd Gamblin <tgamblin@llnl.gov>
This commit is contained in:
Todd Gamblin 2024-11-29 22:55:18 -08:00
parent aa81d59958
commit 175a4bf101
4 changed files with 30 additions and 21 deletions

View File

@ -3,7 +3,7 @@
# #
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""Data structures that represent Spack's dependency relationships.""" """Data structures that represent Spack's dependency relationships."""
from typing import Dict, List from typing import Dict, List, Type
import spack.deptypes as dt import spack.deptypes as dt
import spack.spec import spack.spec
@ -38,7 +38,7 @@ class Dependency:
def __init__( def __init__(
self, self,
pkg: "spack.package_base.PackageBase", pkg: Type["spack.package_base.PackageBase"],
spec: "spack.spec.Spec", spec: "spack.spec.Spec",
depflag: dt.DepFlag = dt.DEFAULT, depflag: dt.DepFlag = dt.DEFAULT,
): ):

View File

@ -21,6 +21,7 @@ class OpenMpi(Package):
* ``conflicts`` * ``conflicts``
* ``depends_on`` * ``depends_on``
* ``extends`` * ``extends``
* ``license``
* ``patch`` * ``patch``
* ``provides`` * ``provides``
* ``resource`` * ``resource``
@ -34,7 +35,7 @@ class OpenMpi(Package):
import collections.abc import collections.abc
import os.path import os.path
import re import re
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Tuple, Type, Union
import llnl.util.tty.color import llnl.util.tty.color
@ -81,7 +82,7 @@ class OpenMpi(Package):
SpecType = str SpecType = str
DepType = Union[Tuple[str, ...], str] DepType = Union[Tuple[str, ...], str]
WhenType = Optional[Union[spack.spec.Spec, str, bool]] WhenType = Optional[Union[spack.spec.Spec, str, bool]]
Patcher = Callable[[Union[spack.package_base.PackageBase, Dependency]], None] Patcher = Callable[[Union[Type[spack.package_base.PackageBase], Dependency]], None]
PatchesType = Union[Patcher, str, List[Union[Patcher, str]]] PatchesType = Union[Patcher, str, List[Union[Patcher, str]]]
@ -218,7 +219,7 @@ def version(
return lambda pkg: _execute_version(pkg, ver, **kwargs) return lambda pkg: _execute_version(pkg, ver, **kwargs)
def _execute_version(pkg, ver, **kwargs): def _execute_version(pkg: Type[spack.package_base.PackageBase], ver: Union[str, int], **kwargs):
if ( if (
(any(s in kwargs for s in spack.util.crypto.hashes) or "checksum" in kwargs) (any(s in kwargs for s in spack.util.crypto.hashes) or "checksum" in kwargs)
and hasattr(pkg, "has_code") and hasattr(pkg, "has_code")
@ -249,7 +250,7 @@ def _execute_version(pkg, ver, **kwargs):
def _depends_on( def _depends_on(
pkg: spack.package_base.PackageBase, pkg: Type[spack.package_base.PackageBase],
spec: spack.spec.Spec, spec: spack.spec.Spec,
*, *,
when: WhenType = None, when: WhenType = None,
@ -329,7 +330,7 @@ def conflicts(conflict_spec: SpecType, when: WhenType = None, msg: Optional[str]
msg (str): optional user defined message msg (str): optional user defined message
""" """
def _execute_conflicts(pkg: spack.package_base.PackageBase): def _execute_conflicts(pkg: Type[spack.package_base.PackageBase]):
# If when is not specified the conflict always holds # If when is not specified the conflict always holds
when_spec = _make_when_spec(when) when_spec = _make_when_spec(when)
if not when_spec: if not when_spec:
@ -370,14 +371,16 @@ def depends_on(
assert type == "build", "languages must be of 'build' type" assert type == "build", "languages must be of 'build' type"
return _language(lang_spec_str=spec, when=when) return _language(lang_spec_str=spec, when=when)
def _execute_depends_on(pkg: spack.package_base.PackageBase): def _execute_depends_on(pkg: Type[spack.package_base.PackageBase]):
_depends_on(pkg, dep_spec, when=when, type=type, patches=patches) _depends_on(pkg, dep_spec, when=when, type=type, patches=patches)
return _execute_depends_on return _execute_depends_on
@directive("disable_redistribute") @directive("disable_redistribute")
def redistribute(source=None, binary=None, when: WhenType = None): def redistribute(
source: Optional[bool] = None, binary: Optional[bool] = None, when: WhenType = None
):
"""Can be used inside a Package definition to declare that """Can be used inside a Package definition to declare that
the package source and/or compiled binaries should not be the package source and/or compiled binaries should not be
redistributed. redistributed.
@ -392,7 +395,10 @@ def redistribute(source=None, binary=None, when: WhenType = None):
def _execute_redistribute( def _execute_redistribute(
pkg: spack.package_base.PackageBase, source=None, binary=None, when: WhenType = None pkg: Type[spack.package_base.PackageBase],
source: Optional[bool],
binary: Optional[bool],
when: WhenType,
): ):
if source is None and binary is None: if source is None and binary is None:
return return
@ -468,7 +474,7 @@ def provides(*specs: SpecType, when: WhenType = None):
when: condition when this provides clause needs to be considered when: condition when this provides clause needs to be considered
""" """
def _execute_provides(pkg: spack.package_base.PackageBase): def _execute_provides(pkg: Type[spack.package_base.PackageBase]):
import spack.parser # Avoid circular dependency import spack.parser # Avoid circular dependency
when_spec = _make_when_spec(when) when_spec = _make_when_spec(when)
@ -516,7 +522,7 @@ def can_splice(
variants will be skipped by '*'. variants will be skipped by '*'.
""" """
def _execute_can_splice(pkg: spack.package_base.PackageBase): def _execute_can_splice(pkg: Type[spack.package_base.PackageBase]):
when_spec = _make_when_spec(when) when_spec = _make_when_spec(when)
if isinstance(match_variants, str) and match_variants != "*": if isinstance(match_variants, str) and match_variants != "*":
raise ValueError( raise ValueError(
@ -557,10 +563,10 @@ def patch(
compressed URL patches) compressed URL patches)
""" """
def _execute_patch(pkg_or_dep: Union[spack.package_base.PackageBase, Dependency]): def _execute_patch(
pkg = pkg_or_dep pkg_or_dep: Union[Type[spack.package_base.PackageBase], Dependency]
if isinstance(pkg, Dependency): ) -> None:
pkg = pkg.pkg pkg = pkg_or_dep.pkg if isinstance(pkg_or_dep, Dependency) else pkg_or_dep
if hasattr(pkg, "has_code") and not pkg.has_code: if hasattr(pkg, "has_code") and not pkg.has_code:
raise UnsupportedPackageDirective( raise UnsupportedPackageDirective(
@ -817,7 +823,9 @@ def _execute_maintainer(pkg):
return _execute_maintainer return _execute_maintainer
def _execute_license(pkg, license_identifier: str, when): def _execute_license(
pkg: Type[spack.package_base.PackageBase], license_identifier: str, when: WhenType
):
# If when is not specified the license always holds # If when is not specified the license always holds
when_spec = _make_when_spec(when) when_spec = _make_when_spec(when)
if not when_spec: if not when_spec:
@ -881,7 +889,7 @@ def requires(*requirement_specs: str, policy="one_of", when=None, msg=None):
msg: optional user defined message msg: optional user defined message
""" """
def _execute_requires(pkg: spack.package_base.PackageBase): def _execute_requires(pkg: Type[spack.package_base.PackageBase]):
if policy not in ("one_of", "any_of"): if policy not in ("one_of", "any_of"):
err_msg = ( err_msg = (
f"the 'policy' argument of the 'requires' directive in {pkg.name} is set " f"the 'policy' argument of the 'requires' directive in {pkg.name} is set "
@ -906,7 +914,7 @@ def _execute_requires(pkg: spack.package_base.PackageBase):
def _language(lang_spec_str: str, *, when: Optional[Union[str, bool]] = None): def _language(lang_spec_str: str, *, when: Optional[Union[str, bool]] = None):
"""Temporary implementation of language virtuals, until compilers are proper dependencies.""" """Temporary implementation of language virtuals, until compilers are proper dependencies."""
def _execute_languages(pkg: spack.package_base.PackageBase): def _execute_languages(pkg: Type[spack.package_base.PackageBase]):
when_spec = _make_when_spec(when) when_spec = _make_when_spec(when)
if not when_spec: if not when_spec:
return return

View File

@ -595,6 +595,7 @@ class PackageBase(WindowsRPath, PackageViewMixin, metaclass=PackageMeta):
patches: Dict[spack.spec.Spec, List[spack.patch.Patch]] patches: Dict[spack.spec.Spec, List[spack.patch.Patch]]
variants: Dict[spack.spec.Spec, Dict[str, spack.variant.Variant]] variants: Dict[spack.spec.Spec, Dict[str, spack.variant.Variant]]
languages: Dict[spack.spec.Spec, Set[str]] languages: Dict[spack.spec.Spec, Set[str]]
licenses: Dict[spack.spec.Spec, str]
splice_specs: Dict[spack.spec.Spec, Tuple[spack.spec.Spec, Union[None, str, List[str]]]] splice_specs: Dict[spack.spec.Spec, Tuple[spack.spec.Spec, Union[None, str, List[str]]]]
#: Store whether a given Spec source/binary should not be redistributed. #: Store whether a given Spec source/binary should not be redistributed.

View File

@ -219,10 +219,10 @@ class MockPackage:
disable_redistribute = {} disable_redistribute = {}
cls = MockPackage cls = MockPackage
spack.directives._execute_redistribute(cls, source=False, when="@1.0") spack.directives._execute_redistribute(cls, source=False, binary=None, when="@1.0")
spec_key = spack.directives._make_when_spec("@1.0") spec_key = spack.directives._make_when_spec("@1.0")
assert not cls.disable_redistribute[spec_key].binary assert not cls.disable_redistribute[spec_key].binary
assert cls.disable_redistribute[spec_key].source assert cls.disable_redistribute[spec_key].source
spack.directives._execute_redistribute(cls, binary=False, when="@1.0") spack.directives._execute_redistribute(cls, source=None, binary=False, when="@1.0")
assert cls.disable_redistribute[spec_key].binary assert cls.disable_redistribute[spec_key].binary
assert cls.disable_redistribute[spec_key].source assert cls.disable_redistribute[spec_key].source