package_base: generify accessor methods for when-keyed dictionaries

This turns some variant-specific methods for dealing with when-keyed dictionaries into
more generic versions, in preparation for conditional version definitions.

`_by_name`, `_names`, etc. are replaced with generic methods for transforming
when-keyed dictionaries:
 * `_by_subkey()`
 * `_subkeys()`
 * `_num_definitions()`
 * `_definitions()`
 * `_remove_overridden_defs()`

And the variant accessors are refactored to use these methods underneath.

To do this, types like `WhenDict` had to be generified, and some `TypeVars`
were added for sortable keys and values.

Signed-off-by: Todd Gamblin <tgamblin@llnl.gov>
This commit is contained in:
Todd Gamblin 2024-11-30 14:55:45 -08:00
parent 7f24b11675
commit 6e2625ae65
3 changed files with 138 additions and 98 deletions

View File

@ -58,6 +58,7 @@
from spack.solver.version_order import concretization_version_order from spack.solver.version_order import concretization_version_order
from spack.stage import DevelopStage, ResourceStage, Stage, StageComposite, compute_stage_name from spack.stage import DevelopStage, ResourceStage, Stage, StageComposite, compute_stage_name
from spack.util.package_hash import package_hash from spack.util.package_hash import package_hash
from spack.util.typing import SupportsRichComparison
from spack.version import GitVersion, StandardVersion from spack.version import GitVersion, StandardVersion
FLAG_HANDLER_RETURN_TYPE = Tuple[ FLAG_HANDLER_RETURN_TYPE = Tuple[
@ -85,32 +86,6 @@
spack_times_log = "install_times.json" spack_times_log = "install_times.json"
def deprecated_version(pkg: "PackageBase", version: Union[str, StandardVersion]) -> bool:
"""Return True iff the version is deprecated.
Arguments:
pkg: The package whose version is to be checked.
version: The version being checked
"""
if not isinstance(version, StandardVersion):
version = StandardVersion.from_string(version)
details = pkg.versions.get(version)
return details is not None and details.get("deprecated", False)
def preferred_version(pkg: "PackageBase"):
"""
Returns a sorted list of the preferred versions of the package.
Arguments:
pkg: The package whose versions are to be assessed.
"""
version, _ = max(pkg.versions.items(), key=concretization_version_order)
return version
class WindowsRPath: class WindowsRPath:
"""Collection of functionality surrounding Windows RPATH specific features """Collection of functionality surrounding Windows RPATH specific features
@ -415,59 +390,77 @@ def remove_files_from_view(self, view, merge_map):
Pb = TypeVar("Pb", bound="PackageBase") Pb = TypeVar("Pb", bound="PackageBase")
WhenDict = Dict[spack.spec.Spec, Dict[str, Any]] # Some typedefs for dealing with when-indexed dictionaries
NameValuesDict = Dict[str, List[Any]] #
NameWhenDict = Dict[str, Dict[spack.spec.Spec, List[Any]]] # Many of the dictionaries on PackageBase are of the form:
# { Spec: { K: V } }
#
# K might be a variant name, a version, etc. V is a definition of some Spack object.
# The methods below transform these types of dictionaries.
K = TypeVar("K", bound=SupportsRichComparison)
V = TypeVar("V")
def _by_name( def _by_subkey(
when_indexed_dictionary: WhenDict, when: bool = False when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], when: bool = False
) -> Union[NameValuesDict, NameWhenDict]: ) -> Dict[K, Union[List[V], Dict[spack.spec.Spec, List[V]]]]:
"""Convert a dict of dicts keyed by when/name into a dict of lists keyed by name. """Convert a dict of dicts keyed by when/subkey into a dict of lists keyed by subkey.
Optional Arguments: Optional Arguments:
when: if ``True``, don't discared the ``when`` specs; return a 2-level dictionary when: if ``True``, don't discared the ``when`` specs; return a 2-level dictionary
keyed by name and when spec. keyed by subkey and when spec.
""" """
# very hard to define this type to be conditional on `when` # very hard to define this type to be conditional on `when`
all_by_name: Dict[str, Any] = {} all_by_subkey: Dict[K, Any] = {}
for when_spec, by_name in when_indexed_dictionary.items(): for when_spec, by_key in when_indexed_dictionary.items():
for name, value in by_name.items(): for key, value in by_key.items():
if when: if when:
when_dict = all_by_name.setdefault(name, {}) when_dict = all_by_subkey.setdefault(key, {})
when_dict.setdefault(when_spec, []).append(value) when_dict.setdefault(when_spec, []).append(value)
else: else:
all_by_name.setdefault(name, []).append(value) all_by_subkey.setdefault(key, []).append(value)
# this needs to preserve the insertion order of whens # this needs to preserve the insertion order of whens
return dict(sorted(all_by_name.items())) return dict(sorted(all_by_subkey.items()))
def _names(when_indexed_dictionary: WhenDict) -> List[str]: def _subkeys(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]]) -> List[K]:
"""Get sorted names from dicts keyed by when/name.""" """Get sorted names from dicts keyed by when/name."""
all_names = set() all_keys = set()
for when, by_name in when_indexed_dictionary.items(): for when, by_key in when_indexed_dictionary.items():
for name in by_name: for key in by_key:
all_names.add(name) all_keys.add(key)
return sorted(all_names) return sorted(all_keys)
WhenVariantList = List[Tuple[spack.spec.Spec, spack.variant.Variant]] def _has_subkey(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], key: K) -> bool:
return any(key in dictionary for dictionary in when_indexed_dictionary.values())
def _remove_overridden_vdefs(variant_defs: WhenVariantList) -> None: def _num_definitions(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]]) -> int:
"""Remove variant defs from the list if their when specs are satisfied by later ones. return sum(len(dictionary) for dictionary in when_indexed_dictionary.values())
Any such variant definitions are *always* overridden by their successor, as it will
match everything the predecessor matches, and the solver will prefer it because of
its higher precedence.
We can just remove these defs from variant definitions and avoid putting them in the def _precedence(obj) -> int:
solver. This is also useful for, e.g., `spack info`, where we don't want to show a """Get either a 'precedence' attribute or item from an object."""
variant from a superclass if it is always overridden by a variant defined in a precedence = getattr(obj, "precedence", None)
subclass. if precedence is None:
raise KeyError(f"Couldn't get precedence from {type(obj)}")
return precedence
def _remove_overridden_defs(defs: List[Tuple[spack.spec.Spec, Any]]) -> None:
"""Remove definitions from the list if their when specs are satisfied by later ones.
Any such definitions are *always* overridden by their successor, as they will
match everything the predecessor matches, and the solver will prefer them because of
their higher precedence.
We can just remove these defs and avoid putting them in the solver. This is also
useful for, e.g., `spack info`, where we don't want to show a variant from a
superclass if it is always overridden by a variant defined in a subclass.
Example:: Example::
@ -485,14 +478,33 @@ class Hipblas:
""" """
i = 0 i = 0
while i < len(variant_defs): while i < len(defs):
when, vdef = variant_defs[i] when, _ = defs[i]
if any(when.satisfies(successor) for successor, _ in variant_defs[i + 1 :]): if any(when.satisfies(successor) for successor, _ in defs[i + 1 :]):
del variant_defs[i] del defs[i]
else: else:
i += 1 i += 1
def _definitions(
when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], key: K
) -> List[Tuple[spack.spec.Spec, V]]:
"""Iterator over (when_spec, Value) for all values with a particular Key."""
# construct a list of defs sorted by precedence
defs: List[Tuple[spack.spec.Spec, V]] = []
for when, values_by_key in when_indexed_dictionary.items():
value_def = values_by_key.get(key)
if value_def:
defs.append((when, value_def))
# With multiple definitions, ensure precedence order and simplify overrides
if len(defs) > 1:
defs.sort(key=lambda v: _precedence(v[1]))
_remove_overridden_defs(defs)
return defs
#: Store whether a given Spec source/binary should not be redistributed. #: Store whether a given Spec source/binary should not be redistributed.
class DisableRedistribute: class DisableRedistribute:
def __init__(self, source, binary): def __init__(self, source, binary):
@ -756,44 +768,32 @@ def __init__(self, spec):
@classmethod @classmethod
def dependency_names(cls): def dependency_names(cls):
return _names(cls.dependencies) return _subkeys(cls.dependencies)
@classmethod @classmethod
def dependencies_by_name(cls, when: bool = False): def dependencies_by_name(cls, when: bool = False):
return _by_name(cls.dependencies, when=when) return _by_subkey(cls.dependencies, when=when)
# Accessors for variants # Accessors for variants
# External code workingw with Variants should go through the methods below # External code working with Variants should go through the methods below
@classmethod @classmethod
def variant_names(cls) -> List[str]: def variant_names(cls) -> List[str]:
return _names(cls.variants) return _subkeys(cls.variants)
@classmethod @classmethod
def has_variant(cls, name) -> bool: def has_variant(cls, name) -> bool:
return any(name in dictionary for dictionary in cls.variants.values()) return _has_subkey(cls.variants, name)
@classmethod @classmethod
def num_variant_definitions(cls) -> int: def num_variant_definitions(cls) -> int:
"""Total number of variant definitions in this class so far.""" """Total number of variant definitions in this class so far."""
return sum(len(variants_by_name) for variants_by_name in cls.variants.values()) return _num_definitions(cls.variants)
@classmethod @classmethod
def variant_definitions(cls, name: str) -> WhenVariantList: def variant_definitions(cls, name: str) -> List[Tuple[spack.spec.Spec, spack.variant.Variant]]:
"""Iterator over (when_spec, Variant) for all variant definitions for a particular name.""" """Iterator over (when_spec, Variant) for all variant definitions for a particular name."""
# construct a list of defs sorted by precedence return _definitions(cls.variants, name)
defs: WhenVariantList = []
for when, variants_by_name in cls.variants.items():
variant_def = variants_by_name.get(name)
if variant_def:
defs.append((when, variant_def))
# With multiple definitions, ensure precedence order and simplify overrides
if len(defs) > 1:
defs.sort(key=lambda v: v[1].precedence)
_remove_overridden_vdefs(defs)
return defs
@classmethod @classmethod
def variant_items(cls) -> Iterable[Tuple[spack.spec.Spec, Dict[str, spack.variant.Variant]]]: def variant_items(cls) -> Iterable[Tuple[spack.spec.Spec, Dict[str, spack.variant.Variant]]]:
@ -2369,6 +2369,32 @@ def possible_dependencies(
return visited return visited
def deprecated_version(pkg: PackageBase, version: Union[str, StandardVersion]) -> bool:
"""Return True iff the version is deprecated.
Arguments:
pkg: The package whose version is to be checked.
version: The version being checked
"""
if not isinstance(version, StandardVersion):
version = StandardVersion.from_string(version)
details = pkg.versions.get(version)
return details is not None and details.get("deprecated", False)
def preferred_version(pkg: PackageBase):
"""
Returns a sorted list of the preferred versions of the package.
Arguments:
pkg: The package whose versions are to be assessed.
"""
version, _ = max(pkg.versions.items(), key=concretization_version_order)
return version
class PackageStillNeededError(InstallError): class PackageStillNeededError(InstallError):
"""Raised when package is still needed by another on uninstall.""" """Raised when package is still needed by another on uninstall."""

View File

@ -0,0 +1,30 @@
# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other: object
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from typing import Any
from typing_extensions import Protocol
class SupportsRichComparison(Protocol):
"""Objects that support =, !=, <, <=, >, and >=."""
def __eq__(self, other: Any) -> bool:
raise NotImplementedError
def __ne__(self, other: Any) -> bool:
raise NotImplementedError
def __lt__(self, other: Any) -> bool:
raise NotImplementedError
def __le__(self, other: Any) -> bool:
raise NotImplementedError
def __gt__(self, other: Any) -> bool:
raise NotImplementedError
def __ge__(self, other: Any) -> bool:
raise NotImplementedError

View File

@ -8,6 +8,7 @@
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
from spack.util.spack_yaml import syaml_dict from spack.util.spack_yaml import syaml_dict
from spack.util.typing import SupportsRichComparison
from .common import ( from .common import (
ALPHA, ALPHA,
@ -156,7 +157,7 @@ def parse_string_components(string: str) -> Tuple[VersionTuple, SeparatorTuple]:
return (release, prerelease), separators return (release, prerelease), separators
class VersionType: class VersionType(SupportsRichComparison):
"""Base type for all versions in Spack (ranges, lists, regular versions, and git versions). """Base type for all versions in Spack (ranges, lists, regular versions, and git versions).
Versions in Spack behave like sets, and support some basic set operations. There are Versions in Spack behave like sets, and support some basic set operations. There are
@ -193,23 +194,6 @@ def union(self, other: "VersionType") -> "VersionType":
"""Return a VersionType containing self and other.""" """Return a VersionType containing self and other."""
raise NotImplementedError raise NotImplementedError
# We can use SupportsRichComparisonT in Python 3.8 or later, but alas in 3.6 we need
# to write all the operators out
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __lt__(self, other: object) -> bool:
raise NotImplementedError
def __gt__(self, other: object) -> bool:
raise NotImplementedError
def __ge__(self, other: object) -> bool:
raise NotImplementedError
def __le__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int: def __hash__(self) -> int:
raise NotImplementedError raise NotImplementedError