diff --git a/lib/spack/spack/package_base.py b/lib/spack/spack/package_base.py index 305af5cb8c1..7ffb06ab4c5 100644 --- a/lib/spack/spack/package_base.py +++ b/lib/spack/spack/package_base.py @@ -58,6 +58,7 @@ from spack.solver.version_order import concretization_version_order from spack.stage import DevelopStage, ResourceStage, Stage, StageComposite, compute_stage_name from spack.util.package_hash import package_hash +from spack.util.typing import SupportsRichComparison from spack.version import GitVersion, StandardVersion FLAG_HANDLER_RETURN_TYPE = Tuple[ @@ -85,32 +86,6 @@ spack_times_log = "install_times.json" -def deprecated_version(pkg: "PackageBase", version: Union[str, StandardVersion]) -> bool: - """Return True iff the version is deprecated. - - Arguments: - pkg: The package whose version is to be checked. - version: The version being checked - """ - if not isinstance(version, StandardVersion): - version = StandardVersion.from_string(version) - - details = pkg.versions.get(version) - return details is not None and details.get("deprecated", False) - - -def preferred_version(pkg: "PackageBase"): - """ - Returns a sorted list of the preferred versions of the package. - - Arguments: - pkg: The package whose versions are to be assessed. - """ - - version, _ = max(pkg.versions.items(), key=concretization_version_order) - return version - - class WindowsRPath: """Collection of functionality surrounding Windows RPATH specific features @@ -415,59 +390,77 @@ def remove_files_from_view(self, view, merge_map): Pb = TypeVar("Pb", bound="PackageBase") -WhenDict = Dict[spack.spec.Spec, Dict[str, Any]] -NameValuesDict = Dict[str, List[Any]] -NameWhenDict = Dict[str, Dict[spack.spec.Spec, List[Any]]] +# Some typedefs for dealing with when-indexed dictionaries +# +# Many of the dictionaries on PackageBase are of the form: +# { Spec: { K: V } } +# +# K might be a variant name, a version, etc. V is a definition of some Spack object. +# The methods below transform these types of dictionaries. +K = TypeVar("K", bound=SupportsRichComparison) +V = TypeVar("V") -def _by_name( - when_indexed_dictionary: WhenDict, when: bool = False -) -> Union[NameValuesDict, NameWhenDict]: - """Convert a dict of dicts keyed by when/name into a dict of lists keyed by name. +def _by_subkey( + when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], when: bool = False +) -> Dict[K, Union[List[V], Dict[spack.spec.Spec, List[V]]]]: + """Convert a dict of dicts keyed by when/subkey into a dict of lists keyed by subkey. Optional Arguments: when: if ``True``, don't discared the ``when`` specs; return a 2-level dictionary - keyed by name and when spec. + keyed by subkey and when spec. """ # very hard to define this type to be conditional on `when` - all_by_name: Dict[str, Any] = {} + all_by_subkey: Dict[K, Any] = {} - for when_spec, by_name in when_indexed_dictionary.items(): - for name, value in by_name.items(): + for when_spec, by_key in when_indexed_dictionary.items(): + for key, value in by_key.items(): if when: - when_dict = all_by_name.setdefault(name, {}) + when_dict = all_by_subkey.setdefault(key, {}) when_dict.setdefault(when_spec, []).append(value) else: - all_by_name.setdefault(name, []).append(value) + all_by_subkey.setdefault(key, []).append(value) # this needs to preserve the insertion order of whens - return dict(sorted(all_by_name.items())) + return dict(sorted(all_by_subkey.items())) -def _names(when_indexed_dictionary: WhenDict) -> List[str]: +def _subkeys(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]]) -> List[K]: """Get sorted names from dicts keyed by when/name.""" - all_names = set() - for when, by_name in when_indexed_dictionary.items(): - for name in by_name: - all_names.add(name) + all_keys = set() + for when, by_key in when_indexed_dictionary.items(): + for key in by_key: + all_keys.add(key) - return sorted(all_names) + return sorted(all_keys) -WhenVariantList = List[Tuple[spack.spec.Spec, spack.variant.Variant]] +def _has_subkey(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], key: K) -> bool: + return any(key in dictionary for dictionary in when_indexed_dictionary.values()) -def _remove_overridden_vdefs(variant_defs: WhenVariantList) -> None: - """Remove variant defs from the list if their when specs are satisfied by later ones. +def _num_definitions(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]]) -> int: + return sum(len(dictionary) for dictionary in when_indexed_dictionary.values()) - Any such variant definitions are *always* overridden by their successor, as it will - match everything the predecessor matches, and the solver will prefer it because of - its higher precedence. - We can just remove these defs from variant definitions and avoid putting them in the - solver. This is also useful for, e.g., `spack info`, where we don't want to show a - variant from a superclass if it is always overridden by a variant defined in a - subclass. +def _precedence(obj) -> int: + """Get either a 'precedence' attribute or item from an object.""" + precedence = getattr(obj, "precedence", None) + if precedence is None: + raise KeyError(f"Couldn't get precedence from {type(obj)}") + return precedence + + +def _remove_overridden_defs(defs: List[Tuple[spack.spec.Spec, Any]]) -> None: + """Remove definitions from the list if their when specs are satisfied by later ones. + + Any such definitions are *always* overridden by their successor, as they will + match everything the predecessor matches, and the solver will prefer them because of + their higher precedence. + + We can just remove these defs and avoid putting them in the solver. This is also + useful for, e.g., `spack info`, where we don't want to show a variant from a + superclass if it is always overridden by a variant defined in a subclass. Example:: @@ -485,14 +478,33 @@ class Hipblas: """ i = 0 - while i < len(variant_defs): - when, vdef = variant_defs[i] - if any(when.satisfies(successor) for successor, _ in variant_defs[i + 1 :]): - del variant_defs[i] + while i < len(defs): + when, _ = defs[i] + if any(when.satisfies(successor) for successor, _ in defs[i + 1 :]): + del defs[i] else: i += 1 +def _definitions( + when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], key: K +) -> List[Tuple[spack.spec.Spec, V]]: + """Iterator over (when_spec, Value) for all values with a particular Key.""" + # construct a list of defs sorted by precedence + defs: List[Tuple[spack.spec.Spec, V]] = [] + for when, values_by_key in when_indexed_dictionary.items(): + value_def = values_by_key.get(key) + if value_def: + defs.append((when, value_def)) + + # With multiple definitions, ensure precedence order and simplify overrides + if len(defs) > 1: + defs.sort(key=lambda v: _precedence(v[1])) + _remove_overridden_defs(defs) + + return defs + + #: Store whether a given Spec source/binary should not be redistributed. class DisableRedistribute: def __init__(self, source, binary): @@ -756,44 +768,32 @@ def __init__(self, spec): @classmethod def dependency_names(cls): - return _names(cls.dependencies) + return _subkeys(cls.dependencies) @classmethod def dependencies_by_name(cls, when: bool = False): - return _by_name(cls.dependencies, when=when) + return _by_subkey(cls.dependencies, when=when) # Accessors for variants - # External code workingw with Variants should go through the methods below + # External code working with Variants should go through the methods below @classmethod def variant_names(cls) -> List[str]: - return _names(cls.variants) + return _subkeys(cls.variants) @classmethod def has_variant(cls, name) -> bool: - return any(name in dictionary for dictionary in cls.variants.values()) + return _has_subkey(cls.variants, name) @classmethod def num_variant_definitions(cls) -> int: """Total number of variant definitions in this class so far.""" - return sum(len(variants_by_name) for variants_by_name in cls.variants.values()) + return _num_definitions(cls.variants) @classmethod - def variant_definitions(cls, name: str) -> WhenVariantList: + def variant_definitions(cls, name: str) -> List[Tuple[spack.spec.Spec, spack.variant.Variant]]: """Iterator over (when_spec, Variant) for all variant definitions for a particular name.""" - # construct a list of defs sorted by precedence - defs: WhenVariantList = [] - for when, variants_by_name in cls.variants.items(): - variant_def = variants_by_name.get(name) - if variant_def: - defs.append((when, variant_def)) - - # With multiple definitions, ensure precedence order and simplify overrides - if len(defs) > 1: - defs.sort(key=lambda v: v[1].precedence) - _remove_overridden_vdefs(defs) - - return defs + return _definitions(cls.variants, name) @classmethod def variant_items(cls) -> Iterable[Tuple[spack.spec.Spec, Dict[str, spack.variant.Variant]]]: @@ -2369,6 +2369,32 @@ def possible_dependencies( return visited +def deprecated_version(pkg: PackageBase, version: Union[str, StandardVersion]) -> bool: + """Return True iff the version is deprecated. + + Arguments: + pkg: The package whose version is to be checked. + version: The version being checked + """ + if not isinstance(version, StandardVersion): + version = StandardVersion.from_string(version) + + details = pkg.versions.get(version) + return details is not None and details.get("deprecated", False) + + +def preferred_version(pkg: PackageBase): + """ + Returns a sorted list of the preferred versions of the package. + + Arguments: + pkg: The package whose versions are to be assessed. + """ + + version, _ = max(pkg.versions.items(), key=concretization_version_order) + return version + + class PackageStillNeededError(InstallError): """Raised when package is still needed by another on uninstall.""" diff --git a/lib/spack/spack/util/typing.py b/lib/spack/spack/util/typing.py new file mode 100644 index 00000000000..7e1525aa4e9 --- /dev/null +++ b/lib/spack/spack/util/typing.py @@ -0,0 +1,30 @@ +# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other: object +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + +from typing import Any + +from typing_extensions import Protocol + + +class SupportsRichComparison(Protocol): + """Objects that support =, !=, <, <=, >, and >=.""" + + def __eq__(self, other: Any) -> bool: + raise NotImplementedError + + def __ne__(self, other: Any) -> bool: + raise NotImplementedError + + def __lt__(self, other: Any) -> bool: + raise NotImplementedError + + def __le__(self, other: Any) -> bool: + raise NotImplementedError + + def __gt__(self, other: Any) -> bool: + raise NotImplementedError + + def __ge__(self, other: Any) -> bool: + raise NotImplementedError diff --git a/lib/spack/spack/version/version_types.py b/lib/spack/spack/version/version_types.py index 4c7a9606f46..b2e27445f9a 100644 --- a/lib/spack/spack/version/version_types.py +++ b/lib/spack/spack/version/version_types.py @@ -8,6 +8,7 @@ from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union from spack.util.spack_yaml import syaml_dict +from spack.util.typing import SupportsRichComparison from .common import ( ALPHA, @@ -156,7 +157,7 @@ def parse_string_components(string: str) -> Tuple[VersionTuple, SeparatorTuple]: return (release, prerelease), separators -class VersionType: +class VersionType(SupportsRichComparison): """Base type for all versions in Spack (ranges, lists, regular versions, and git versions). Versions in Spack behave like sets, and support some basic set operations. There are @@ -193,23 +194,6 @@ def union(self, other: "VersionType") -> "VersionType": """Return a VersionType containing self and other.""" raise NotImplementedError - # We can use SupportsRichComparisonT in Python 3.8 or later, but alas in 3.6 we need - # to write all the operators out - def __eq__(self, other: object) -> bool: - raise NotImplementedError - - def __lt__(self, other: object) -> bool: - raise NotImplementedError - - def __gt__(self, other: object) -> bool: - raise NotImplementedError - - def __ge__(self, other: object) -> bool: - raise NotImplementedError - - def __le__(self, other: object) -> bool: - raise NotImplementedError - def __hash__(self) -> int: raise NotImplementedError