package_base: generify accessor methods for when-keyed dictionaries

This turns some variant-specific methods for dealing with when-keyed dictionaries into
more generic versions, in preparation for conditional version definitions.

`_by_name`, `_names`, etc. are replaced with generic methods for transforming
when-keyed dictionaries:
 * `_by_subkey()`
 * `_subkeys()`
 * `_num_definitions()`
 * `_definitions()`
 * `_remove_overridden_defs()`

And the variant accessors are refactored to use these methods underneath.

To do this, types like `WhenDict` had to be generified, and some `TypeVars`
were added for sortable keys and values.

Signed-off-by: Todd Gamblin <tgamblin@llnl.gov>
This commit is contained in:
Todd Gamblin 2024-11-30 14:55:45 -08:00
parent 7f24b11675
commit 6e2625ae65
3 changed files with 138 additions and 98 deletions

View File

@ -58,6 +58,7 @@
from spack.solver.version_order import concretization_version_order
from spack.stage import DevelopStage, ResourceStage, Stage, StageComposite, compute_stage_name
from spack.util.package_hash import package_hash
from spack.util.typing import SupportsRichComparison
from spack.version import GitVersion, StandardVersion
FLAG_HANDLER_RETURN_TYPE = Tuple[
@ -85,32 +86,6 @@
spack_times_log = "install_times.json"
def deprecated_version(pkg: "PackageBase", version: Union[str, StandardVersion]) -> bool:
"""Return True iff the version is deprecated.
Arguments:
pkg: The package whose version is to be checked.
version: The version being checked
"""
if not isinstance(version, StandardVersion):
version = StandardVersion.from_string(version)
details = pkg.versions.get(version)
return details is not None and details.get("deprecated", False)
def preferred_version(pkg: "PackageBase"):
"""
Returns a sorted list of the preferred versions of the package.
Arguments:
pkg: The package whose versions are to be assessed.
"""
version, _ = max(pkg.versions.items(), key=concretization_version_order)
return version
class WindowsRPath:
"""Collection of functionality surrounding Windows RPATH specific features
@ -415,59 +390,77 @@ def remove_files_from_view(self, view, merge_map):
Pb = TypeVar("Pb", bound="PackageBase")
WhenDict = Dict[spack.spec.Spec, Dict[str, Any]]
NameValuesDict = Dict[str, List[Any]]
NameWhenDict = Dict[str, Dict[spack.spec.Spec, List[Any]]]
# Some typedefs for dealing with when-indexed dictionaries
#
# Many of the dictionaries on PackageBase are of the form:
# { Spec: { K: V } }
#
# K might be a variant name, a version, etc. V is a definition of some Spack object.
# The methods below transform these types of dictionaries.
K = TypeVar("K", bound=SupportsRichComparison)
V = TypeVar("V")
def _by_name(
when_indexed_dictionary: WhenDict, when: bool = False
) -> Union[NameValuesDict, NameWhenDict]:
"""Convert a dict of dicts keyed by when/name into a dict of lists keyed by name.
def _by_subkey(
when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], when: bool = False
) -> Dict[K, Union[List[V], Dict[spack.spec.Spec, List[V]]]]:
"""Convert a dict of dicts keyed by when/subkey into a dict of lists keyed by subkey.
Optional Arguments:
when: if ``True``, don't discared the ``when`` specs; return a 2-level dictionary
keyed by name and when spec.
keyed by subkey and when spec.
"""
# very hard to define this type to be conditional on `when`
all_by_name: Dict[str, Any] = {}
all_by_subkey: Dict[K, Any] = {}
for when_spec, by_name in when_indexed_dictionary.items():
for name, value in by_name.items():
for when_spec, by_key in when_indexed_dictionary.items():
for key, value in by_key.items():
if when:
when_dict = all_by_name.setdefault(name, {})
when_dict = all_by_subkey.setdefault(key, {})
when_dict.setdefault(when_spec, []).append(value)
else:
all_by_name.setdefault(name, []).append(value)
all_by_subkey.setdefault(key, []).append(value)
# this needs to preserve the insertion order of whens
return dict(sorted(all_by_name.items()))
return dict(sorted(all_by_subkey.items()))
def _names(when_indexed_dictionary: WhenDict) -> List[str]:
def _subkeys(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]]) -> List[K]:
"""Get sorted names from dicts keyed by when/name."""
all_names = set()
for when, by_name in when_indexed_dictionary.items():
for name in by_name:
all_names.add(name)
all_keys = set()
for when, by_key in when_indexed_dictionary.items():
for key in by_key:
all_keys.add(key)
return sorted(all_names)
return sorted(all_keys)
WhenVariantList = List[Tuple[spack.spec.Spec, spack.variant.Variant]]
def _has_subkey(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], key: K) -> bool:
return any(key in dictionary for dictionary in when_indexed_dictionary.values())
def _remove_overridden_vdefs(variant_defs: WhenVariantList) -> None:
"""Remove variant defs from the list if their when specs are satisfied by later ones.
def _num_definitions(when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]]) -> int:
return sum(len(dictionary) for dictionary in when_indexed_dictionary.values())
Any such variant definitions are *always* overridden by their successor, as it will
match everything the predecessor matches, and the solver will prefer it because of
its higher precedence.
We can just remove these defs from variant definitions and avoid putting them in the
solver. This is also useful for, e.g., `spack info`, where we don't want to show a
variant from a superclass if it is always overridden by a variant defined in a
subclass.
def _precedence(obj) -> int:
"""Get either a 'precedence' attribute or item from an object."""
precedence = getattr(obj, "precedence", None)
if precedence is None:
raise KeyError(f"Couldn't get precedence from {type(obj)}")
return precedence
def _remove_overridden_defs(defs: List[Tuple[spack.spec.Spec, Any]]) -> None:
"""Remove definitions from the list if their when specs are satisfied by later ones.
Any such definitions are *always* overridden by their successor, as they will
match everything the predecessor matches, and the solver will prefer them because of
their higher precedence.
We can just remove these defs and avoid putting them in the solver. This is also
useful for, e.g., `spack info`, where we don't want to show a variant from a
superclass if it is always overridden by a variant defined in a subclass.
Example::
@ -485,14 +478,33 @@ class Hipblas:
"""
i = 0
while i < len(variant_defs):
when, vdef = variant_defs[i]
if any(when.satisfies(successor) for successor, _ in variant_defs[i + 1 :]):
del variant_defs[i]
while i < len(defs):
when, _ = defs[i]
if any(when.satisfies(successor) for successor, _ in defs[i + 1 :]):
del defs[i]
else:
i += 1
def _definitions(
when_indexed_dictionary: Dict[spack.spec.Spec, Dict[K, V]], key: K
) -> List[Tuple[spack.spec.Spec, V]]:
"""Iterator over (when_spec, Value) for all values with a particular Key."""
# construct a list of defs sorted by precedence
defs: List[Tuple[spack.spec.Spec, V]] = []
for when, values_by_key in when_indexed_dictionary.items():
value_def = values_by_key.get(key)
if value_def:
defs.append((when, value_def))
# With multiple definitions, ensure precedence order and simplify overrides
if len(defs) > 1:
defs.sort(key=lambda v: _precedence(v[1]))
_remove_overridden_defs(defs)
return defs
#: Store whether a given Spec source/binary should not be redistributed.
class DisableRedistribute:
def __init__(self, source, binary):
@ -756,44 +768,32 @@ def __init__(self, spec):
@classmethod
def dependency_names(cls):
return _names(cls.dependencies)
return _subkeys(cls.dependencies)
@classmethod
def dependencies_by_name(cls, when: bool = False):
return _by_name(cls.dependencies, when=when)
return _by_subkey(cls.dependencies, when=when)
# Accessors for variants
# External code workingw with Variants should go through the methods below
# External code working with Variants should go through the methods below
@classmethod
def variant_names(cls) -> List[str]:
return _names(cls.variants)
return _subkeys(cls.variants)
@classmethod
def has_variant(cls, name) -> bool:
return any(name in dictionary for dictionary in cls.variants.values())
return _has_subkey(cls.variants, name)
@classmethod
def num_variant_definitions(cls) -> int:
"""Total number of variant definitions in this class so far."""
return sum(len(variants_by_name) for variants_by_name in cls.variants.values())
return _num_definitions(cls.variants)
@classmethod
def variant_definitions(cls, name: str) -> WhenVariantList:
def variant_definitions(cls, name: str) -> List[Tuple[spack.spec.Spec, spack.variant.Variant]]:
"""Iterator over (when_spec, Variant) for all variant definitions for a particular name."""
# construct a list of defs sorted by precedence
defs: WhenVariantList = []
for when, variants_by_name in cls.variants.items():
variant_def = variants_by_name.get(name)
if variant_def:
defs.append((when, variant_def))
# With multiple definitions, ensure precedence order and simplify overrides
if len(defs) > 1:
defs.sort(key=lambda v: v[1].precedence)
_remove_overridden_vdefs(defs)
return defs
return _definitions(cls.variants, name)
@classmethod
def variant_items(cls) -> Iterable[Tuple[spack.spec.Spec, Dict[str, spack.variant.Variant]]]:
@ -2369,6 +2369,32 @@ def possible_dependencies(
return visited
def deprecated_version(pkg: PackageBase, version: Union[str, StandardVersion]) -> bool:
"""Return True iff the version is deprecated.
Arguments:
pkg: The package whose version is to be checked.
version: The version being checked
"""
if not isinstance(version, StandardVersion):
version = StandardVersion.from_string(version)
details = pkg.versions.get(version)
return details is not None and details.get("deprecated", False)
def preferred_version(pkg: PackageBase):
"""
Returns a sorted list of the preferred versions of the package.
Arguments:
pkg: The package whose versions are to be assessed.
"""
version, _ = max(pkg.versions.items(), key=concretization_version_order)
return version
class PackageStillNeededError(InstallError):
"""Raised when package is still needed by another on uninstall."""

View File

@ -0,0 +1,30 @@
# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other: object
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from typing import Any
from typing_extensions import Protocol
class SupportsRichComparison(Protocol):
"""Objects that support =, !=, <, <=, >, and >=."""
def __eq__(self, other: Any) -> bool:
raise NotImplementedError
def __ne__(self, other: Any) -> bool:
raise NotImplementedError
def __lt__(self, other: Any) -> bool:
raise NotImplementedError
def __le__(self, other: Any) -> bool:
raise NotImplementedError
def __gt__(self, other: Any) -> bool:
raise NotImplementedError
def __ge__(self, other: Any) -> bool:
raise NotImplementedError

View File

@ -8,6 +8,7 @@
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union
from spack.util.spack_yaml import syaml_dict
from spack.util.typing import SupportsRichComparison
from .common import (
ALPHA,
@ -156,7 +157,7 @@ def parse_string_components(string: str) -> Tuple[VersionTuple, SeparatorTuple]:
return (release, prerelease), separators
class VersionType:
class VersionType(SupportsRichComparison):
"""Base type for all versions in Spack (ranges, lists, regular versions, and git versions).
Versions in Spack behave like sets, and support some basic set operations. There are
@ -193,23 +194,6 @@ def union(self, other: "VersionType") -> "VersionType":
"""Return a VersionType containing self and other."""
raise NotImplementedError
# We can use SupportsRichComparisonT in Python 3.8 or later, but alas in 3.6 we need
# to write all the operators out
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __lt__(self, other: object) -> bool:
raise NotImplementedError
def __gt__(self, other: object) -> bool:
raise NotImplementedError
def __ge__(self, other: object) -> bool:
raise NotImplementedError
def __le__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int:
raise NotImplementedError