Allow conditional variants (#24858)

A common question from users has been how to model variants 
that are new in new versions of a package, or variants that are 
dependent on other variants. Our stock answer so far has been
an unsatisfying combination of "just have it do nothing in the old
version" and "tell Spack it conflicts".

This PR enables conditional variants, on any spec condition. The 
syntax is straightforward, and matches that of previous features.
This commit is contained in:
Greg Becker 2021-11-03 00:11:31 -07:00 committed by GitHub
parent 78c08fccd5
commit 67cd92e6a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 217 additions and 32 deletions

View File

@ -1419,6 +1419,60 @@ other similar operations:
).with_default('auto').with_non_feature_values('auto'), ).with_default('auto').with_non_feature_values('auto'),
) )
^^^^^^^^^^^^^^^^^^^^
Conditional Variants
^^^^^^^^^^^^^^^^^^^^
The variant directive accepts a ``when`` clause. The variant will only
be present on specs that otherwise satisfy the spec listed as the
``when`` clause. For example, the following class has a variant
``bar`` when it is at version 2.0 or higher.
.. code-block:: python
class Foo(Package):
...
variant('bar', default=False, when='@2.0:', description='help message')
The ``when`` clause follows the same syntax and accepts the same
values as the ``when`` argument of
:py:func:`spack.directives.depends_on`
^^^^^^^^^^^^^^^^^^^
Overriding Variants
^^^^^^^^^^^^^^^^^^^
Packages may override variants for several reasons, most often to
change the default from a variant defined in a parent class or to
change the conditions under which a variant is present on the spec.
When a variant is defined multiple times, whether in the same package
file or in a subclass and a superclass, the last definition is used
for all attributes **except** for the ``when`` clauses. The ``when``
clauses are accumulated through all invocations, and the variant is
present on the spec if any of the accumulated conditions are
satisfied.
For example, consider the following package:
.. code-block:: python
class Foo(Package):
...
variant('bar', default=False, when='@1.0', description='help1')
variant('bar', default=True, when='platform=darwin', description='help2')
...
This package ``foo`` has a variant ``bar`` when the spec satisfies
either ``@1.0`` or ``platform=darwin``, but not for other platforms at
other versions. The default for this variant, when it is present, is
always ``True``, regardless of which condition of the variant is
satisfied. This allows packages to override variants in packages or
build system classes from which they inherit, by modifying the variant
values without modifying the ``when`` clause. It also allows a package
to implement ``or`` semantics for a variant ``when`` clause by
duplicating the variant definition.
------------------------------------ ------------------------------------
Resources (expanding extra tarballs) Resources (expanding extra tarballs)
------------------------------------ ------------------------------------

View File

@ -521,7 +521,8 @@ def _activate_or_not(
# Create a list of pairs. Each pair includes a configuration # Create a list of pairs. Each pair includes a configuration
# option and whether or not that option is activated # option and whether or not that option is activated
if set(self.variants[variant].values) == set((True, False)): variant_desc, _ = self.variants[variant]
if set(variant_desc.values) == set((True, False)):
# BoolValuedVariant carry information about a single option. # BoolValuedVariant carry information about a single option.
# Nonetheless, for uniformity of treatment we'll package them # Nonetheless, for uniformity of treatment we'll package them
# in an iterable of one element. # in an iterable of one element.
@ -534,8 +535,8 @@ def _activate_or_not(
# package's build system. It excludes values which have special # package's build system. It excludes values which have special
# meanings and do not correspond to features (e.g. "none") # meanings and do not correspond to features (e.g. "none")
feature_values = getattr( feature_values = getattr(
self.variants[variant].values, 'feature_values', None variant_desc.values, 'feature_values', None
) or self.variants[variant].values ) or variant_desc.values
options = [ options = [
(value, (value,

View File

@ -433,7 +433,8 @@ def config_prefer_upstream(args):
or var_name not in spec.package.variants): or var_name not in spec.package.variants):
continue continue
if variant.value != spec.package.variants[var_name].default: variant_desc, _ = spec.package.variants[var_name]
if variant.value != variant_desc.default:
variants.append(str(variant)) variants.append(str(variant))
variants.sort() variants.sort()
variants = ' '.join(variants) variants = ' '.join(variants)

View File

@ -57,7 +57,7 @@ def variant(s):
class VariantFormatter(object): class VariantFormatter(object):
def __init__(self, variants): def __init__(self, variants):
self.variants = variants self.variants = variants
self.headers = ('Name [Default]', 'Allowed values', 'Description') self.headers = ('Name [Default]', 'When', 'Allowed values', 'Description')
# Formats # Formats
fmt_name = '{0} [{1}]' fmt_name = '{0} [{1}]'
@ -68,9 +68,11 @@ def __init__(self, variants):
self.column_widths = [len(x) for x in self.headers] self.column_widths = [len(x) for x in self.headers]
# Expand columns based on max line lengths # Expand columns based on max line lengths
for k, v in variants.items(): for k, e in variants.items():
v, w = e
candidate_max_widths = ( candidate_max_widths = (
len(fmt_name.format(k, self.default(v))), # Name [Default] len(fmt_name.format(k, self.default(v))), # Name [Default]
len(str(w)),
len(v.allowed_values), # Allowed values len(v.allowed_values), # Allowed values
len(v.description) # Description len(v.description) # Description
) )
@ -78,26 +80,29 @@ def __init__(self, variants):
self.column_widths = ( self.column_widths = (
max(self.column_widths[0], candidate_max_widths[0]), max(self.column_widths[0], candidate_max_widths[0]),
max(self.column_widths[1], candidate_max_widths[1]), max(self.column_widths[1], candidate_max_widths[1]),
max(self.column_widths[2], candidate_max_widths[2]) max(self.column_widths[2], candidate_max_widths[2]),
max(self.column_widths[3], candidate_max_widths[3])
) )
# Don't let name or possible values be less than max widths # Don't let name or possible values be less than max widths
_, cols = tty.terminal_size() _, cols = tty.terminal_size()
max_name = min(self.column_widths[0], 30) max_name = min(self.column_widths[0], 30)
max_vals = min(self.column_widths[1], 20) max_when = min(self.column_widths[1], 30)
max_vals = min(self.column_widths[2], 20)
# allow the description column to extend as wide as the terminal. # allow the description column to extend as wide as the terminal.
max_description = min( max_description = min(
self.column_widths[2], self.column_widths[3],
# min width 70 cols, 14 cols of margins and column spacing # min width 70 cols, 14 cols of margins and column spacing
max(cols, 70) - max_name - max_vals - 14, max(cols, 70) - max_name - max_vals - 14,
) )
self.column_widths = (max_name, max_vals, max_description) self.column_widths = (max_name, max_when, max_vals, max_description)
# Compute the format # Compute the format
self.fmt = "%%-%ss%%-%ss%%s" % ( self.fmt = "%%-%ss%%-%ss%%-%ss%%s" % (
self.column_widths[0] + 4, self.column_widths[0] + 4,
self.column_widths[1] + 4 self.column_widths[1] + 4,
self.column_widths[2] + 4
) )
def default(self, v): def default(self, v):
@ -115,21 +120,27 @@ def lines(self):
underline = tuple([w * "=" for w in self.column_widths]) underline = tuple([w * "=" for w in self.column_widths])
yield ' ' + self.fmt % underline yield ' ' + self.fmt % underline
yield '' yield ''
for k, v in sorted(self.variants.items()): for k, e in sorted(self.variants.items()):
v, w = e
name = textwrap.wrap( name = textwrap.wrap(
'{0} [{1}]'.format(k, self.default(v)), '{0} [{1}]'.format(k, self.default(v)),
width=self.column_widths[0] width=self.column_widths[0]
) )
if len(w) == 1:
w = w[0]
if w == spack.spec.Spec():
w = '--'
when = textwrap.wrap(str(w), width=self.column_widths[1])
allowed = v.allowed_values.replace('True, False', 'on, off') allowed = v.allowed_values.replace('True, False', 'on, off')
allowed = textwrap.wrap(allowed, width=self.column_widths[1]) allowed = textwrap.wrap(allowed, width=self.column_widths[2])
description = [] description = []
for d_line in v.description.split('\n'): for d_line in v.description.split('\n'):
description += textwrap.wrap( description += textwrap.wrap(
d_line, d_line,
width=self.column_widths[2] width=self.column_widths[3]
) )
for t in zip_longest( for t in zip_longest(
name, allowed, description, fillvalue='' name, when, allowed, description, fillvalue=''
): ):
yield " " + self.fmt % t yield " " + self.fmt % t
@ -232,7 +243,7 @@ def print_text_info(pkg):
formatter = VariantFormatter(pkg.variants) formatter = VariantFormatter(pkg.variants)
for line in formatter.lines: for line in formatter.lines:
color.cprint(line) color.cprint(color.cescape(line))
if hasattr(pkg, 'phases') and pkg.phases: if hasattr(pkg, 'phases') and pkg.phases:
color.cprint('') color.cprint('')

View File

@ -384,7 +384,8 @@ def concretize_variants(self, spec):
changed = False changed = False
preferred_variants = PackagePrefs.preferred_variants(spec.name) preferred_variants = PackagePrefs.preferred_variants(spec.name)
pkg_cls = spec.package_class pkg_cls = spec.package_class
for name, variant in pkg_cls.variants.items(): for name, entry in pkg_cls.variants.items():
variant, when = entry
var = spec.variants.get(name, None) var = spec.variants.get(name, None)
if var and '*' in var: if var and '*' in var:
# remove variant wildcard before concretizing # remove variant wildcard before concretizing
@ -392,12 +393,16 @@ def concretize_variants(self, spec):
# multivalue variant, a concrete variant cannot have the value # multivalue variant, a concrete variant cannot have the value
# wildcard, and a wildcard does not constrain a variant # wildcard, and a wildcard does not constrain a variant
spec.variants.pop(name) spec.variants.pop(name)
if name not in spec.variants: if name not in spec.variants and any(spec.satisfies(w)
for w in when):
changed = True changed = True
if name in preferred_variants: if name in preferred_variants:
spec.variants[name] = preferred_variants.get(name) spec.variants[name] = preferred_variants.get(name)
else: else:
spec.variants[name] = variant.make_default() spec.variants[name] = variant.make_default()
if name in spec.variants and not any(spec.satisfies(w)
for w in when):
raise vt.InvalidVariantForSpecError(name, when, spec)
return changed return changed

View File

@ -244,7 +244,7 @@ def _wrapper(*args, **kwargs):
if DirectiveMeta._when_constraints_from_context: if DirectiveMeta._when_constraints_from_context:
# Check that directives not yet supporting the when= argument # Check that directives not yet supporting the when= argument
# are not used inside the context manager # are not used inside the context manager
if decorated_function.__name__ in ('version', 'variant'): if decorated_function.__name__ == 'version':
msg = ('directive "{0}" cannot be used within a "when"' msg = ('directive "{0}" cannot be used within a "when"'
' context since it does not support a "when=" ' ' context since it does not support a "when=" '
'argument') 'argument')
@ -562,7 +562,8 @@ def variant(
description='', description='',
values=None, values=None,
multi=None, multi=None,
validator=None): validator=None,
when=None):
"""Define a variant for the package. Packager can specify a default """Define a variant for the package. Packager can specify a default
value as well as a text description. value as well as a text description.
@ -581,6 +582,8 @@ def variant(
logic. It receives the package name, the variant name and a tuple logic. It receives the package name, the variant name and a tuple
of values and should raise an instance of SpackError if the group of values and should raise an instance of SpackError if the group
doesn't meet the additional constraints doesn't meet the additional constraints
when (spack.spec.Spec, bool): optional condition on which the
variant applies
Raises: Raises:
DirectiveError: if arguments passed to the directive are invalid DirectiveError: if arguments passed to the directive are invalid
@ -640,14 +643,23 @@ def _raise_default_not_set(pkg):
description = str(description).strip() description = str(description).strip()
def _execute_variant(pkg): def _execute_variant(pkg):
when_spec = make_when_spec(when)
when_specs = [when_spec]
if not re.match(spack.spec.identifier_re, name): if not re.match(spack.spec.identifier_re, name):
directive = 'variant' directive = 'variant'
msg = "Invalid variant name in {0}: '{1}'" msg = "Invalid variant name in {0}: '{1}'"
raise DirectiveError(directive, msg.format(pkg.name, name)) raise DirectiveError(directive, msg.format(pkg.name, name))
pkg.variants[name] = spack.variant.Variant( if name in pkg.variants:
# We accumulate when specs, but replace the rest of the variant
# with the newer values
_, orig_when = pkg.variants[name]
when_specs += orig_when
pkg.variants[name] = (spack.variant.Variant(
name, default, description, values, multi, validator name, default, description, values, multi, validator
) ), when_specs)
return _execute_variant return _execute_variant

View File

@ -721,8 +721,12 @@ def pkg_rules(self, pkg, tests):
self.gen.newline() self.gen.newline()
# variants # variants
for name, variant in sorted(pkg.variants.items()): for name, entry in sorted(pkg.variants.items()):
self.gen.fact(fn.variant(pkg.name, name)) variant, when = entry
for w in when:
cond_id = self.condition(w, name=pkg.name)
self.gen.fact(fn.variant_condition(cond_id, pkg.name, name))
single_value = not variant.multi single_value = not variant.multi
if single_value: if single_value:
@ -788,7 +792,7 @@ def condition(self, required_spec, imposed_spec=None, name=None):
Arguments: Arguments:
required_spec (spack.spec.Spec): the spec that triggers this condition required_spec (spack.spec.Spec): the spec that triggers this condition
imposed_spec (spack.spec.Spec or None): the sepc with constraints that imposed_spec (spack.spec.Spec or None): the spec with constraints that
are imposed when this condition is triggered are imposed when this condition is triggered
name (str or None): name for `required_spec` (required if name (str or None): name for `required_spec` (required if
required_spec is anonymous, ignored if not) required_spec is anonymous, ignored if not)
@ -1087,7 +1091,7 @@ class Body(object):
reserved_names = spack.directives.reserved_names reserved_names = spack.directives.reserved_names
if not spec.virtual and vname not in reserved_names: if not spec.virtual and vname not in reserved_names:
try: try:
variant_def = spec.package.variants[vname] variant_def, _ = spec.package.variants[vname]
except KeyError: except KeyError:
msg = 'variant "{0}" not found in package "{1}"' msg = 'variant "{0}" not found in package "{1}"'
raise RuntimeError(msg.format(vname, spec.name)) raise RuntimeError(msg.format(vname, spec.name))

View File

@ -350,6 +350,20 @@ external_conditions_hold(Package, LocalIndex) :-
%----------------------------------------------------------------------------- %-----------------------------------------------------------------------------
% Variant semantics % Variant semantics
%----------------------------------------------------------------------------- %-----------------------------------------------------------------------------
% a variant is a variant of a package if it is a variant under some condition
% and that condition holds
variant(Package, Variant) :- variant_condition(ID, Package, Variant),
condition_holds(ID).
% a variant cannot be set if it is not a variant on the package
:- variant_set(Package, Variant),
not variant(Package, Variant),
error("Unsatisfied conditional variants cannot be set").
% a variant cannot take on a value if it is not a variant of the package
:- variant_value(Package, Variant, _), not variant(Package, Variant),
error("Unsatisfied conditional variants cannot take on a variant value").
% one variant value for single-valued variants. % one variant value for single-valued variants.
1 { 1 {
variant_value(Package, Variant, Value) variant_value(Package, Variant, Value)

View File

@ -3081,7 +3081,7 @@ def update_variant_validate(self, variant_name, values):
if not isinstance(values, tuple): if not isinstance(values, tuple):
values = (values,) values = (values,)
pkg_variant = self.package_class.variants[variant_name] pkg_variant, _ = self.package_class.variants[variant_name]
for value in values: for value in values:
if self.variants.get(variant_name): if self.variants.get(variant_name):

View File

@ -16,6 +16,7 @@
import spack.error import spack.error
import spack.platforms import spack.platforms
import spack.repo import spack.repo
import spack.variant as vt
from spack.concretize import find_spec from spack.concretize import find_spec
from spack.spec import Spec from spack.spec import Spec
from spack.util.mock_package import MockPackageMultiRepo from spack.util.mock_package import MockPackageMultiRepo
@ -738,6 +739,41 @@ def test_compiler_conflicts_in_package_py(self, spec_str, expected_str):
s = Spec(spec_str).concretized() s = Spec(spec_str).concretized()
assert s.satisfies(expected_str) assert s.satisfies(expected_str)
@pytest.mark.parametrize('spec_str,expected,unexpected', [
('conditional-variant-pkg@1.0',
['two_whens'],
['version_based', 'variant_based']),
('conditional-variant-pkg@2.0',
['version_based', 'variant_based'],
['two_whens']),
('conditional-variant-pkg@2.0~version_based',
['version_based'],
['variant_based', 'two_whens']),
('conditional-variant-pkg@2.0+version_based+variant_based',
['version_based', 'variant_based', 'two_whens'],
[])
])
def test_conditional_variants(self, spec_str, expected, unexpected):
s = Spec(spec_str).concretized()
for var in expected:
assert s.satisfies('%s=*' % var)
for var in unexpected:
assert not s.satisfies('%s=*' % var)
@pytest.mark.parametrize('bad_spec', [
'@1.0~version_based',
'@1.0+version_based',
'@2.0~version_based+variant_based',
'@2.0+version_based~variant_based+two_whens',
])
def test_conditional_variants_fail(self, bad_spec):
with pytest.raises(
(spack.error.UnsatisfiableSpecError,
vt.InvalidVariantForSpecError)
):
_ = Spec('conditional-variant-pkg' + bad_spec).concretized()
@pytest.mark.parametrize('spec_str,expected,unexpected', [ @pytest.mark.parametrize('spec_str,expected,unexpected', [
('py-extension3 ^python@3.5.1', [], ['py-extension1']), ('py-extension3 ^python@3.5.1', [], ['py-extension1']),
('py-extension3 ^python@2.7.11', ['py-extension1'], []), ('py-extension3 ^python@2.7.11', ['py-extension1'], []),

View File

@ -247,7 +247,8 @@ def test_variant_defaults_are_parsable_from_cli():
"""Ensures that variant defaults are parsable from cli.""" """Ensures that variant defaults are parsable from cli."""
failing = [] failing = []
for pkg in spack.repo.path.all_packages(): for pkg in spack.repo.path.all_packages():
for variant_name, variant in pkg.variants.items(): for variant_name, entry in pkg.variants.items():
variant, _ = entry
default_is_parsable = ( default_is_parsable = (
# Permitting a default that is an instance on 'int' permits # Permitting a default that is an instance on 'int' permits
# to have foo=false or foo=0. Other falsish values are # to have foo=false or foo=0. Other falsish values are
@ -262,7 +263,8 @@ def test_variant_defaults_are_parsable_from_cli():
def test_variant_defaults_listed_explicitly_in_values(): def test_variant_defaults_listed_explicitly_in_values():
failing = [] failing = []
for pkg in spack.repo.path.all_packages(): for pkg in spack.repo.path.all_packages():
for variant_name, variant in pkg.variants.items(): for variant_name, entry in pkg.variants.items():
variant, _ = entry
vspec = variant.make_default() vspec = variant.make_default()
try: try:
variant.validate_or_raise(vspec, pkg=pkg) variant.validate_or_raise(vspec, pkg=pkg)

View File

@ -181,6 +181,17 @@ def variant_cls(self):
return BoolValuedVariant return BoolValuedVariant
return SingleValuedVariant return SingleValuedVariant
def __eq__(self, other):
return (self.name == other.name and
self.default == other.default and
self.values == other.values and
self.multi == other.multi and
self.single_value_validator == other.single_value_validator and
self.group_validator == other.group_validator)
def __ne__(self, other):
return not self == other
def implicit_variant_conversion(method): def implicit_variant_conversion(method):
"""Converts other to type(self) and calls method(self, other) """Converts other to type(self) and calls method(self, other)
@ -645,10 +656,10 @@ def substitute_abstract_variants(spec):
new_variant = SingleValuedVariant(name, v._original_value) new_variant = SingleValuedVariant(name, v._original_value)
spec.variants.substitute(new_variant) spec.variants.substitute(new_variant)
continue continue
pkg_variant = spec.package_class.variants.get(name, None) if name not in spec.package_class.variants:
if not pkg_variant:
failed.append(name) failed.append(name)
continue continue
pkg_variant, _ = spec.package_class.variants[name]
new_variant = pkg_variant.make_variant(v._original_value) new_variant = pkg_variant.make_variant(v._original_value)
pkg_variant.validate_or_raise(new_variant, spec.package_class) pkg_variant.validate_or_raise(new_variant, spec.package_class)
spec.variants.substitute(new_variant) spec.variants.substitute(new_variant)
@ -880,6 +891,16 @@ def __init__(self, variant, invalid_values, pkg):
) )
class InvalidVariantForSpecError(error.SpecError):
"""Raised when an invalid conditional variant is specified."""
def __init__(self, variant, when, spec):
msg = "Invalid variant {0} for spec {1}.\n"
msg += "{0} is only available for {1.name} when satisfying one of {2}."
super(InvalidVariantForSpecError, self).__init__(
msg.format(variant, spec, when)
)
class UnsatisfiableVariantSpecError(error.UnsatisfiableSpecError): class UnsatisfiableVariantSpecError(error.UnsatisfiableSpecError):
"""Raised when a spec variant conflicts with package constraints.""" """Raised when a spec variant conflicts with package constraints."""

View File

@ -0,0 +1,24 @@
# Copyright 2013-2021 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
class ConditionalVariantPkg(Package):
"""This package is used to test conditional variants."""
homepage = "http://www.example.com/conditional-variant-pkg"
url = "http://www.unit-test-should-replace-this-url/conditional-variant-1.0.tar.gz"
version('1.0', '0123456789abcdef0123456789abcdef')
version('2.0', 'abcdef0123456789abcdef0123456789')
variant('version_based', default=True, when='@2.0:',
description="Check that version constraints work")
variant('variant_based', default=False, when='+version_based',
description="Check that variants can depend on variants")
variant('two_whens', default=False, when='@1.0')
variant('two_whens', default=False, when='+variant_based')
def install(self, spec, prefix):
assert False