variants: fix narrowing multi -> single -> bool (#49880)

* `x=*` constrained by `+x` now produces a boolean valued variant instead of a multi-valued variant.
   
* Values are now always stored as a tuple internally, whether bool, single or multi-valued. 

* Value assignment has a stricter api to prevent ambiguity / type issues related to 
   `variant.value = "x"` / `variant.value = ["x"]` / `variant.value = ("x",)`. It's now `variant.set("x", ...)` for 
   single and multi-valued variants.

* The `_original_value` prop is dropped, since it was unused.

* The wildcard `*` is no longer a possible variant value in any type of variant, since the *parser*
   deals with it and creates a variant with no values.
This commit is contained in:
Harmen Stoppels 2025-04-16 09:44:38 +02:00 committed by GitHub
parent 1dc9bac745
commit 883bbf3826
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 293 additions and 359 deletions

View File

@ -1755,15 +1755,17 @@ def define_variant(
pkg_fact(fn.variant_condition(name, vid, cond_id)) pkg_fact(fn.variant_condition(name, vid, cond_id))
# record type so we can construct the variant when we read it back in # record type so we can construct the variant when we read it back in
self.gen.fact(fn.variant_type(vid, variant_def.variant_type.value)) self.gen.fact(fn.variant_type(vid, variant_def.variant_type.string))
if variant_def.sticky: if variant_def.sticky:
pkg_fact(fn.variant_sticky(vid)) pkg_fact(fn.variant_sticky(vid))
# define defaults for this variant definition # define defaults for this variant definition
defaults = variant_def.make_default().value if variant_def.multi else [variant_def.default] if variant_def.multi:
for val in sorted(defaults): for val in sorted(variant_def.make_default().values):
pkg_fact(fn.variant_default_value_from_package_py(vid, val)) pkg_fact(fn.variant_default_value_from_package_py(vid, val))
else:
pkg_fact(fn.variant_default_value_from_package_py(vid, variant_def.default))
# define possible values for this variant definition # define possible values for this variant definition
values = variant_def.values values = variant_def.values
@ -1791,7 +1793,9 @@ def define_variant(
# make a spec indicating whether the variant has this conditional value # make a spec indicating whether the variant has this conditional value
variant_has_value = spack.spec.Spec() variant_has_value = spack.spec.Spec()
variant_has_value.variants[name] = vt.VariantBase(name, value.value) variant_has_value.variants[name] = vt.VariantValue(
vt.VariantType.MULTI, name, (value.value,)
)
if value.when: if value.when:
# the conditional value is always "possible", but it imposes its when condition as # the conditional value is always "possible", but it imposes its when condition as
@ -2373,7 +2377,7 @@ def preferred_variants(self, pkg_name):
) )
continue continue
for value in variant.value_as_tuple: for value in variant.values:
for variant_def in variant_defs: for variant_def in variant_defs:
self.variant_values_from_specs.add((pkg_name, id(variant_def), value)) self.variant_values_from_specs.add((pkg_name, id(variant_def), value))
self.gen.fact( self.gen.fact(
@ -2491,7 +2495,7 @@ def _spec_clauses(
if variant.value == ("*",): if variant.value == ("*",):
continue continue
for value in variant.value_as_tuple: for value in variant.values:
# ensure that the value *can* be valid for the spec # ensure that the value *can* be valid for the spec
if spec.name and not spec.concrete and not spack.repo.PATH.is_virtual(spec.name): if spec.name and not spec.concrete and not spack.repo.PATH.is_virtual(spec.name):
variant_defs = vt.prevalidate_variant_value( variant_defs = vt.prevalidate_variant_value(
@ -3828,13 +3832,13 @@ def node_os(self, node, os):
def node_target(self, node, target): def node_target(self, node, target):
self._arch(node).target = target self._arch(node).target = target
def variant_selected(self, node, name, value, variant_type, variant_id): def variant_selected(self, node, name: str, value: str, variant_type: str, variant_id):
spec = self._specs[node] spec = self._specs[node]
variant = spec.variants.get(name) variant = spec.variants.get(name)
if not variant: if not variant:
spec.variants[name] = vt.VariantType(variant_type).variant_class(name, value) spec.variants[name] = vt.VariantValue.from_concretizer(name, value, variant_type)
else: else:
assert variant_type == vt.VariantType.MULTI.value, ( assert variant_type == "multi", (
f"Can't have multiple values for single-valued variant: " f"Can't have multiple values for single-valued variant: "
f"{node}, {name}, {value}, {variant_type}, {variant_id}" f"{node}, {name}, {value}, {variant_type}, {variant_id}"
) )
@ -4213,10 +4217,10 @@ def _inject_patches_variant(root: spack.spec.Spec) -> None:
continue continue
patches = list(spec_to_patches[id(spec)]) patches = list(spec_to_patches[id(spec)])
variant: vt.MultiValuedVariant = spec.variants.setdefault( variant: vt.VariantValue = spec.variants.setdefault(
"patches", vt.MultiValuedVariant("patches", ()) "patches", vt.MultiValuedVariant("patches", ())
) )
variant.value = tuple(p.sha256 for p in patches) variant.set(*(p.sha256 for p in patches))
# FIXME: Monkey patches variant to store patches order # FIXME: Monkey patches variant to store patches order
ordered_hashes = [(*p.ordering_key, p.sha256) for p in patches if p.ordering_key] ordered_hashes = [(*p.ordering_key, p.sha256) for p in patches if p.ordering_key]
ordered_hashes.sort() ordered_hashes.sort()

View File

@ -1698,7 +1698,9 @@ def _dependencies_dict(self, depflag: dt.DepFlag = dt.ALL):
result[key] = list(group) result[key] = list(group)
return result return result
def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> None: def _add_flag(
self, name: str, value: Union[str, bool], propagate: bool, concrete: bool
) -> None:
"""Called by the parser to add a known flag""" """Called by the parser to add a known flag"""
if propagate and name in vt.RESERVED_NAMES: if propagate and name in vt.RESERVED_NAMES:
@ -1708,6 +1710,7 @@ def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> N
valid_flags = FlagMap.valid_compiler_flags() valid_flags = FlagMap.valid_compiler_flags()
if name == "arch" or name == "architecture": if name == "arch" or name == "architecture":
assert type(value) is str, "architecture have a string value"
parts = tuple(value.split("-")) parts = tuple(value.split("-"))
plat, os, tgt = parts if len(parts) == 3 else (None, None, value) plat, os, tgt = parts if len(parts) == 3 else (None, None, value)
self._set_architecture(platform=plat, os=os, target=tgt) self._set_architecture(platform=plat, os=os, target=tgt)
@ -1721,17 +1724,15 @@ def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> N
self.namespace = value self.namespace = value
elif name in valid_flags: elif name in valid_flags:
assert self.compiler_flags is not None assert self.compiler_flags is not None
assert type(value) is str, f"{name} must have a string value"
flags_and_propagation = spack.compilers.flags.tokenize_flags(value, propagate) flags_and_propagation = spack.compilers.flags.tokenize_flags(value, propagate)
flag_group = " ".join(x for (x, y) in flags_and_propagation) flag_group = " ".join(x for (x, y) in flags_and_propagation)
for flag, propagation in flags_and_propagation: for flag, propagation in flags_and_propagation:
self.compiler_flags.add_flag(name, flag, propagation, flag_group) self.compiler_flags.add_flag(name, flag, propagation, flag_group)
else: else:
if str(value).upper() == "TRUE" or str(value).upper() == "FALSE": self.variants[name] = vt.VariantValue.from_string_or_bool(
self.variants[name] = vt.BoolValuedVariant(name, value, propagate) name, value, propagate=propagate, concrete=concrete
elif concrete: )
self.variants[name] = vt.MultiValuedVariant(name, value, propagate)
else:
self.variants[name] = vt.VariantBase(name, value, propagate)
def _set_architecture(self, **kwargs): def _set_architecture(self, **kwargs):
"""Called by the parser to set the architecture.""" """Called by the parser to set the architecture."""
@ -4481,7 +4482,7 @@ def __init__(self, spec: Spec):
def __setitem__(self, name, vspec): def __setitem__(self, name, vspec):
# Raise a TypeError if vspec is not of the right type # Raise a TypeError if vspec is not of the right type
if not isinstance(vspec, vt.VariantBase): if not isinstance(vspec, vt.VariantValue):
raise TypeError( raise TypeError(
"VariantMap accepts only values of variant types " "VariantMap accepts only values of variant types "
f"[got {type(vspec).__name__} instead]" f"[got {type(vspec).__name__} instead]"
@ -4621,7 +4622,7 @@ def __str__(self):
bool_keys = [] bool_keys = []
kv_keys = [] kv_keys = []
for key in sorted_keys: for key in sorted_keys:
if isinstance(self[key].value, bool): if self[key].type == vt.VariantType.BOOL:
bool_keys.append(key) bool_keys.append(key)
else: else:
kv_keys.append(key) kv_keys.append(key)
@ -4654,7 +4655,8 @@ def substitute_abstract_variants(spec: Spec):
unknown = [] unknown = []
for name, v in spec.variants.items(): for name, v in spec.variants.items():
if name == "dev_path": if name == "dev_path":
spec.variants.substitute(vt.SingleValuedVariant(name, v._original_value)) v.type = vt.VariantType.SINGLE
v.concrete = True
continue continue
elif name in vt.RESERVED_NAMES: elif name in vt.RESERVED_NAMES:
continue continue
@ -4677,7 +4679,7 @@ def substitute_abstract_variants(spec: Spec):
if rest: if rest:
continue continue
new_variant = pkg_variant.make_variant(v._original_value) new_variant = pkg_variant.make_variant(*v.values)
pkg_variant.validate_or_raise(new_variant, spec.name) pkg_variant.validate_or_raise(new_variant, spec.name)
spec.variants.substitute(new_variant) spec.variants.substitute(new_variant)
@ -4803,7 +4805,7 @@ def from_node_dict(cls, node):
for val in values: for val in values:
spec.compiler_flags.add_flag(name, val, propagate) spec.compiler_flags.add_flag(name, val, propagate)
else: else:
spec.variants[name] = vt.MultiValuedVariant.from_node_dict( spec.variants[name] = vt.VariantValue.from_node_dict(
name, values, propagate=propagate, abstract=name in abstract_variants name, values, propagate=propagate, abstract=name in abstract_variants
) )
@ -4829,7 +4831,7 @@ def from_node_dict(cls, node):
patches = node["patches"] patches = node["patches"]
if len(patches) > 0: if len(patches) > 0:
mvar = spec.variants.setdefault("patches", vt.MultiValuedVariant("patches", ())) mvar = spec.variants.setdefault("patches", vt.MultiValuedVariant("patches", ()))
mvar.value = patches mvar.set(*patches)
# FIXME: Monkey patches mvar to store patches order # FIXME: Monkey patches mvar to store patches order
mvar._patches_in_order_of_appearance = patches mvar._patches_in_order_of_appearance = patches

View File

@ -62,7 +62,7 @@
import sys import sys
import traceback import traceback
import warnings import warnings
from typing import Iterator, List, Optional, Tuple from typing import Iterator, List, Optional, Tuple, Union
from llnl.util.tty import color from llnl.util.tty import color
@ -369,7 +369,7 @@ def raise_parsing_error(string: str, cause: Optional[Exception] = None):
"""Raise a spec parsing error with token context.""" """Raise a spec parsing error with token context."""
raise SpecParsingError(string, self.ctx.current_token, self.literal_str) from cause raise SpecParsingError(string, self.ctx.current_token, self.literal_str) from cause
def add_flag(name: str, value: str, propagate: bool, concrete: bool): def add_flag(name: str, value: Union[str, bool], propagate: bool, concrete: bool):
"""Wrapper around ``Spec._add_flag()`` that adds parser context to errors raised.""" """Wrapper around ``Spec._add_flag()`` that adds parser context to errors raised."""
try: try:
initial_spec._add_flag(name, value, propagate, concrete) initial_spec._add_flag(name, value, propagate, concrete)

View File

@ -973,13 +973,10 @@ def test_spec_formatting_bad_formats(self, default_mock_concretization, fmt_str)
with pytest.raises(SpecFormatStringError): with pytest.raises(SpecFormatStringError):
spec.format(fmt_str) spec.format(fmt_str)
def test_combination_of_wildcard_or_none(self): def test_wildcard_is_invalid_variant_value(self):
# Test that using 'none' and another value raises """The spec string x=* is parsed as a multi-valued variant with values the empty set.
with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot be combined"): That excludes * as a literal variant value."""
Spec("multivalue-variant foo=none,bar") with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot use reserved value"):
# Test that using wildcard and another value raises
with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot be combined"):
Spec("multivalue-variant foo=*,bar") Spec("multivalue-variant foo=*,bar")
def test_errors_in_variant_directive(self): def test_errors_in_variant_directive(self):

View File

@ -21,7 +21,7 @@
SingleValuedVariant, SingleValuedVariant,
UnsatisfiableVariantSpecError, UnsatisfiableVariantSpecError,
Variant, Variant,
VariantBase, VariantValue,
disjoint_sets, disjoint_sets,
) )
@ -29,53 +29,37 @@
class TestMultiValuedVariant: class TestMultiValuedVariant:
def test_initialization(self): def test_initialization(self):
# Basic properties # Basic properties
a = MultiValuedVariant("foo", "bar,baz") a = MultiValuedVariant("foo", ("bar", "baz"))
assert repr(a) == "MultiValuedVariant('foo', 'bar,baz')"
assert str(a) == "foo:=bar,baz" assert str(a) == "foo:=bar,baz"
assert a.values == ("bar", "baz")
assert a.value == ("bar", "baz") assert a.value == ("bar", "baz")
assert "bar" in a assert "bar" in a
assert "baz" in a assert "baz" in a
assert eval(repr(a)) == a
# Spaces are trimmed
b = MultiValuedVariant("foo", "bar, baz")
assert repr(b) == "MultiValuedVariant('foo', 'bar, baz')"
assert str(b) == "foo:=bar,baz"
assert b.value == ("bar", "baz")
assert "bar" in b
assert "baz" in b
assert a == b
assert hash(a) == hash(b)
assert eval(repr(b)) == a
# Order is not important # Order is not important
c = MultiValuedVariant("foo", "baz, bar") c = MultiValuedVariant("foo", ("baz", "bar"))
assert repr(c) == "MultiValuedVariant('foo', 'baz, bar')"
assert str(c) == "foo:=bar,baz" assert str(c) == "foo:=bar,baz"
assert c.value == ("bar", "baz") assert c.values == ("bar", "baz")
assert "bar" in c assert "bar" in c
assert "baz" in c assert "baz" in c
assert a == c assert a == c
assert hash(a) == hash(c) assert hash(a) == hash(c)
assert eval(repr(c)) == a
# Check the copy # Check the copy
d = a.copy() d = a.copy()
assert repr(a) == repr(d)
assert str(a) == str(d) assert str(a) == str(d)
assert d.value == ("bar", "baz") assert d.values == ("bar", "baz")
assert "bar" in d assert "bar" in d
assert "baz" in d assert "baz" in d
assert a == d assert a == d
assert a is not d assert a is not d
assert hash(a) == hash(d) assert hash(a) == hash(d)
assert eval(repr(d)) == a
def test_satisfies(self): def test_satisfies(self):
a = MultiValuedVariant("foo", "bar,baz") a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("foo", "bar") b = MultiValuedVariant("foo", ("bar",))
c = MultiValuedVariant("fee", "bar,baz") c = MultiValuedVariant("fee", ("bar", "baz"))
d = MultiValuedVariant("foo", "True") d = MultiValuedVariant("foo", (True,))
# concrete, different values do not satisfy each other # concrete, different values do not satisfy each other
assert not a.satisfies(b) and not b.satisfies(a) assert not a.satisfies(b) and not b.satisfies(a)
@ -85,21 +69,19 @@ def test_satisfies(self):
# eachother # eachother
b_sv = SingleValuedVariant("foo", "bar") b_sv = SingleValuedVariant("foo", "bar")
assert b.satisfies(b_sv) and b_sv.satisfies(b) assert b.satisfies(b_sv) and b_sv.satisfies(b)
d_sv = SingleValuedVariant("foo", "True") d_sv = SingleValuedVariant("foo", True)
assert d.satisfies(d_sv) and d_sv.satisfies(d) assert d.satisfies(d_sv) and d_sv.satisfies(d)
almost_d_bv = SingleValuedVariant("foo", "true") almost_d_bv = SingleValuedVariant("foo", True)
assert not d.satisfies(almost_d_bv) assert d.satisfies(almost_d_bv)
# BoolValuedVariant actually stores the value as a boolean, whereas with MV and SV the d_bv = BoolValuedVariant("foo", True)
# value is string "True". assert d.satisfies(d_bv) and d_bv.satisfies(d)
d_bv = BoolValuedVariant("foo", "True")
assert not d.satisfies(d_bv) and not d_bv.satisfies(d)
def test_intersects(self): def test_intersects(self):
a = MultiValuedVariant("foo", "bar,baz") a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("foo", "True") b = MultiValuedVariant("foo", (True,))
c = MultiValuedVariant("fee", "bar,baz") c = MultiValuedVariant("fee", ("bar", "baz"))
d = MultiValuedVariant("foo", "bar,barbaz") d = MultiValuedVariant("foo", ("bar", "barbaz"))
# concrete, different values do not intersect. # concrete, different values do not intersect.
assert not a.intersects(b) and not b.intersects(a) assert not a.intersects(b) and not b.intersects(a)
@ -110,47 +92,45 @@ def test_intersects(self):
assert not c.intersects(d) and not d.intersects(c) assert not c.intersects(d) and not d.intersects(c)
# SV and MV intersect if they have the same concrete value. # SV and MV intersect if they have the same concrete value.
b_sv = SingleValuedVariant("foo", "True") b_sv = SingleValuedVariant("foo", True)
assert b.intersects(b_sv) assert b.intersects(b_sv)
assert not c.intersects(b_sv) assert not c.intersects(b_sv)
# BoolValuedVariant stores a bool, which is not the same as the string "True" in MV. # BoolValuedVariant intersects if the value is the same
b_bv = BoolValuedVariant("foo", "True") b_bv = BoolValuedVariant("foo", True)
assert not b.intersects(b_bv) assert b.intersects(b_bv)
assert not c.intersects(b_bv) assert not c.intersects(b_bv)
def test_constrain(self): def test_constrain(self):
# Concrete values cannot be constrained # Concrete values cannot be constrained
a = MultiValuedVariant("foo", "bar,baz") a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("foo", "bar") b = MultiValuedVariant("foo", ("bar",))
with pytest.raises(UnsatisfiableVariantSpecError): with pytest.raises(UnsatisfiableVariantSpecError):
a.constrain(b) a.constrain(b)
with pytest.raises(UnsatisfiableVariantSpecError): with pytest.raises(UnsatisfiableVariantSpecError):
b.constrain(a) b.constrain(a)
# Try to constrain on the same value # Try to constrain on the same value
a = MultiValuedVariant("foo", "bar,baz") a = MultiValuedVariant("foo", ("bar", "baz"))
b = a.copy() b = a.copy()
assert not a.constrain(b) assert not a.constrain(b)
assert a == b == MultiValuedVariant("foo", "bar,baz") assert a == b == MultiValuedVariant("foo", ("bar", "baz"))
# Try to constrain on a different name # Try to constrain on a different name
a = MultiValuedVariant("foo", "bar,baz") a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("fee", "bar") b = MultiValuedVariant("fee", ("bar",))
with pytest.raises(UnsatisfiableVariantSpecError): with pytest.raises(UnsatisfiableVariantSpecError):
a.constrain(b) a.constrain(b)
def test_yaml_entry(self): def test_yaml_entry(self):
a = MultiValuedVariant("foo", "bar,baz,barbaz") a = MultiValuedVariant("foo", ("bar", "baz", "barbaz"))
b = MultiValuedVariant("foo", "bar, baz, barbaz") expected = ("foo", sorted(("bar", "baz", "barbaz")))
expected = ("foo", sorted(["bar", "baz", "barbaz"]))
assert a.yaml_entry() == expected assert a.yaml_entry() == expected
assert b.yaml_entry() == expected
a = MultiValuedVariant("foo", "bar") a = MultiValuedVariant("foo", ("bar",))
expected = ("foo", sorted(["bar"])) expected = ("foo", sorted(["bar"]))
assert a.yaml_entry() == expected assert a.yaml_entry() == expected
@ -160,26 +140,20 @@ class TestSingleValuedVariant:
def test_initialization(self): def test_initialization(self):
# Basic properties # Basic properties
a = SingleValuedVariant("foo", "bar") a = SingleValuedVariant("foo", "bar")
assert repr(a) == "SingleValuedVariant('foo', 'bar')"
assert str(a) == "foo=bar" assert str(a) == "foo=bar"
assert a.values == ("bar",)
assert a.value == "bar" assert a.value == "bar"
assert "bar" in a assert "bar" in a
assert eval(repr(a)) == a
# Raise if multiple values are passed
with pytest.raises(ValueError):
SingleValuedVariant("foo", "bar, baz")
# Check the copy # Check the copy
b = a.copy() b = a.copy()
assert repr(a) == repr(b)
assert str(a) == str(b) assert str(a) == str(b)
assert b.values == ("bar",)
assert b.value == "bar" assert b.value == "bar"
assert "bar" in b assert "bar" in b
assert a == b assert a == b
assert a is not b assert a is not b
assert hash(a) == hash(b) assert hash(a) == hash(b)
assert eval(repr(b)) == a
def test_satisfies(self): def test_satisfies(self):
a = SingleValuedVariant("foo", "bar") a = SingleValuedVariant("foo", "bar")
@ -247,54 +221,37 @@ def test_yaml_entry(self):
class TestBoolValuedVariant: class TestBoolValuedVariant:
def test_initialization(self): def test_initialization(self):
# Basic properties - True value # Basic properties - True value
for v in (True, "True", "TRUE", "TrUe"): a = BoolValuedVariant("foo", True)
a = BoolValuedVariant("foo", v)
assert repr(a) == "BoolValuedVariant('foo', {0})".format(repr(v))
assert str(a) == "+foo" assert str(a) == "+foo"
assert a.value is True assert a.value is True
assert a.values == (True,)
assert True in a assert True in a
assert eval(repr(a)) == a
# Copy - True value # Copy - True value
b = a.copy() b = a.copy()
assert repr(a) == repr(b)
assert str(a) == str(b) assert str(a) == str(b)
assert b.value is True assert b.value is True
assert b.values == (True,)
assert True in b assert True in b
assert a == b assert a == b
assert a is not b assert a is not b
assert hash(a) == hash(b) assert hash(a) == hash(b)
assert eval(repr(b)) == a
# Basic properties - False value # Copy - False value
for v in (False, "False", "FALSE", "FaLsE"): a = BoolValuedVariant("foo", False)
a = BoolValuedVariant("foo", v)
assert repr(a) == "BoolValuedVariant('foo', {0})".format(repr(v))
assert str(a) == "~foo"
assert a.value is False
assert False in a
assert eval(repr(a)) == a
# Copy - True value
b = a.copy() b = a.copy()
assert repr(a) == repr(b)
assert str(a) == str(b) assert str(a) == str(b)
assert b.value is False assert b.value is False
assert b.values == (False,)
assert False in b assert False in b
assert a == b assert a == b
assert a is not b assert a is not b
assert eval(repr(b)) == a
# Invalid values
for v in ("bar", "bar,baz"):
with pytest.raises(ValueError):
BoolValuedVariant("foo", v)
def test_satisfies(self): def test_satisfies(self):
a = BoolValuedVariant("foo", True) a = BoolValuedVariant("foo", True)
b = BoolValuedVariant("foo", False) b = BoolValuedVariant("foo", False)
c = BoolValuedVariant("fee", False) c = BoolValuedVariant("fee", False)
d = BoolValuedVariant("foo", "True") d = BoolValuedVariant("foo", True)
# concrete, different values do not satisfy each other # concrete, different values do not satisfy each other
assert not a.satisfies(b) and not b.satisfies(a) assert not a.satisfies(b) and not b.satisfies(a)
@ -325,7 +282,7 @@ def test_intersects(self):
a = BoolValuedVariant("foo", True) a = BoolValuedVariant("foo", True)
b = BoolValuedVariant("fee", True) b = BoolValuedVariant("fee", True)
c = BoolValuedVariant("foo", False) c = BoolValuedVariant("foo", False)
d = BoolValuedVariant("foo", "True") d = BoolValuedVariant("foo", True)
# concrete, different values do not intersect each other # concrete, different values do not intersect each other
assert not a.intersects(b) and not b.intersects(a) assert not a.intersects(b) and not b.intersects(a)
@ -347,7 +304,7 @@ def test_intersects(self):
def test_constrain(self): def test_constrain(self):
# Try to constrain on a value equal to self # Try to constrain on a value equal to self
a = BoolValuedVariant("foo", "True") a = BoolValuedVariant("foo", True)
b = BoolValuedVariant("foo", True) b = BoolValuedVariant("foo", True)
assert not a.constrain(b) assert not a.constrain(b)
@ -375,24 +332,24 @@ def test_constrain(self):
assert a == BoolValuedVariant("foo", True) assert a == BoolValuedVariant("foo", True)
def test_yaml_entry(self): def test_yaml_entry(self):
a = BoolValuedVariant("foo", "True") a = BoolValuedVariant("foo", True)
expected = ("foo", True) expected = ("foo", True)
assert a.yaml_entry() == expected assert a.yaml_entry() == expected
a = BoolValuedVariant("foo", "False") a = BoolValuedVariant("foo", False)
expected = ("foo", False) expected = ("foo", False)
assert a.yaml_entry() == expected assert a.yaml_entry() == expected
def test_from_node_dict(): def test_from_node_dict():
a = MultiValuedVariant.from_node_dict("foo", ["bar"]) a = VariantValue.from_node_dict("foo", ["bar"])
assert type(a) is MultiValuedVariant assert a.type == spack.variant.VariantType.MULTI
a = MultiValuedVariant.from_node_dict("foo", "bar") a = VariantValue.from_node_dict("foo", "bar")
assert type(a) is SingleValuedVariant assert a.type == spack.variant.VariantType.SINGLE
a = MultiValuedVariant.from_node_dict("foo", "true") a = VariantValue.from_node_dict("foo", "true")
assert type(a) is BoolValuedVariant assert a.type == spack.variant.VariantType.BOOL
class TestVariant: class TestVariant:
@ -406,7 +363,7 @@ def test_validation(self):
# Multiple values are not allowed # Multiple values are not allowed
with pytest.raises(MultipleValuesInExclusiveVariantError): with pytest.raises(MultipleValuesInExclusiveVariantError):
vspec.value = "bar,baz" vspec.set("bar", "baz")
# Inconsistent vspec # Inconsistent vspec
vspec.name = "FOO" vspec.name = "FOO"
@ -415,10 +372,10 @@ def test_validation(self):
# Valid multi-value vspec # Valid multi-value vspec
a.multi = True a.multi = True
vspec = a.make_variant("bar,baz") vspec = a.make_variant("bar", "baz")
a.validate_or_raise(vspec, "test-package") a.validate_or_raise(vspec, "test-package")
# Add an invalid value # Add an invalid value
vspec.value = "bar,baz,barbaz" vspec.set("bar", "baz", "barbaz")
with pytest.raises(InvalidVariantValueError): with pytest.raises(InvalidVariantValueError):
a.validate_or_raise(vspec, "test-package") a.validate_or_raise(vspec, "test-package")
@ -429,12 +386,12 @@ def validator(x):
except ValueError: except ValueError:
return False return False
a = Variant("foo", default=1024, description="", values=validator, multi=False) a = Variant("foo", default="1024", description="", values=validator, multi=False)
vspec = a.make_default() vspec = a.make_default()
a.validate_or_raise(vspec, "test-package") a.validate_or_raise(vspec, "test-package")
vspec.value = 2056 vspec.set("2056")
a.validate_or_raise(vspec, "test-package") a.validate_or_raise(vspec, "test-package")
vspec.value = "foo" vspec.set("foo")
with pytest.raises(InvalidVariantValueError): with pytest.raises(InvalidVariantValueError):
a.validate_or_raise(vspec, "test-package") a.validate_or_raise(vspec, "test-package")
@ -464,9 +421,9 @@ def test_invalid_values(self) -> None:
a["foo"] = 2 a["foo"] = 2
# Duplicate variant # Duplicate variant
a["foo"] = MultiValuedVariant("foo", "bar,baz") a["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
with pytest.raises(DuplicateVariantError): with pytest.raises(DuplicateVariantError):
a["foo"] = MultiValuedVariant("foo", "bar") a["foo"] = MultiValuedVariant("foo", ("bar",))
with pytest.raises(DuplicateVariantError): with pytest.raises(DuplicateVariantError):
a["foo"] = SingleValuedVariant("foo", "bar") a["foo"] = SingleValuedVariant("foo", "bar")
@ -476,7 +433,7 @@ def test_invalid_values(self) -> None:
# Non matching names between key and vspec.name # Non matching names between key and vspec.name
with pytest.raises(KeyError): with pytest.raises(KeyError):
a["bar"] = MultiValuedVariant("foo", "bar") a["bar"] = MultiValuedVariant("foo", ("bar",))
def test_set_item(self) -> None: def test_set_item(self) -> None:
# Check that all the three types of variants are accepted # Check that all the three types of variants are accepted
@ -484,7 +441,7 @@ def test_set_item(self) -> None:
a["foo"] = BoolValuedVariant("foo", True) a["foo"] = BoolValuedVariant("foo", True)
a["bar"] = SingleValuedVariant("bar", "baz") a["bar"] = SingleValuedVariant("bar", "baz")
a["foobar"] = MultiValuedVariant("foobar", "a, b, c, d, e") a["foobar"] = MultiValuedVariant("foobar", ("a", "b", "c", "d", "e"))
def test_substitute(self) -> None: def test_substitute(self) -> None:
# Check substitution of a key that exists # Check substitution of a key that exists
@ -500,13 +457,13 @@ def test_substitute(self) -> None:
def test_satisfies_and_constrain(self) -> None: def test_satisfies_and_constrain(self) -> None:
# foo=bar foobar=fee feebar=foo # foo=bar foobar=fee feebar=foo
a = VariantMap(Spec()) a = VariantMap(Spec())
a["foo"] = MultiValuedVariant("foo", "bar") a["foo"] = MultiValuedVariant("foo", ("bar",))
a["foobar"] = SingleValuedVariant("foobar", "fee") a["foobar"] = SingleValuedVariant("foobar", "fee")
a["feebar"] = SingleValuedVariant("feebar", "foo") a["feebar"] = SingleValuedVariant("feebar", "foo")
# foo=bar,baz foobar=fee shared=True # foo=bar,baz foobar=fee shared=True
b = VariantMap(Spec()) b = VariantMap(Spec())
b["foo"] = MultiValuedVariant("foo", "bar, baz") b["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
b["foobar"] = SingleValuedVariant("foobar", "fee") b["foobar"] = SingleValuedVariant("foobar", "fee")
b["shared"] = BoolValuedVariant("shared", True) b["shared"] = BoolValuedVariant("shared", True)
@ -516,7 +473,7 @@ def test_satisfies_and_constrain(self) -> None:
# foo=bar,baz foobar=fee feebar=foo shared=True # foo=bar,baz foobar=fee feebar=foo shared=True
c = VariantMap(Spec()) c = VariantMap(Spec())
c["foo"] = MultiValuedVariant("foo", "bar, baz") c["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
c["foobar"] = SingleValuedVariant("foobar", "fee") c["foobar"] = SingleValuedVariant("foobar", "fee")
c["feebar"] = SingleValuedVariant("feebar", "foo") c["feebar"] = SingleValuedVariant("feebar", "foo")
c["shared"] = BoolValuedVariant("shared", True) c["shared"] = BoolValuedVariant("shared", True)
@ -529,14 +486,14 @@ def test_copy(self) -> None:
a = VariantMap(Spec()) a = VariantMap(Spec())
a["foo"] = BoolValuedVariant("foo", True) a["foo"] = BoolValuedVariant("foo", True)
a["bar"] = SingleValuedVariant("bar", "baz") a["bar"] = SingleValuedVariant("bar", "baz")
a["foobar"] = MultiValuedVariant("foobar", "a, b, c, d, e") a["foobar"] = MultiValuedVariant("foobar", ("a", "b", "c", "d", "e"))
c = a.copy() c = a.copy()
assert a == c assert a == c
def test_str(self) -> None: def test_str(self) -> None:
c = VariantMap(Spec()) c = VariantMap(Spec())
c["foo"] = MultiValuedVariant("foo", "bar, baz") c["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
c["foobar"] = SingleValuedVariant("foobar", "fee") c["foobar"] = SingleValuedVariant("foobar", "fee")
c["feebar"] = SingleValuedVariant("feebar", "foo") c["feebar"] = SingleValuedVariant("feebar", "foo")
c["shared"] = BoolValuedVariant("shared", True) c["shared"] = BoolValuedVariant("shared", True)
@ -635,10 +592,10 @@ def test_wild_card_valued_variants_equivalent_to_str():
several_arbitrary_values = ("doe", "re", "mi") several_arbitrary_values = ("doe", "re", "mi")
# "*" case # "*" case
wild_output = wild_var.make_variant(several_arbitrary_values) wild_output = wild_var.make_variant(*several_arbitrary_values)
wild_var.validate_or_raise(wild_output, "test-package") wild_var.validate_or_raise(wild_output, "test-package")
# str case # str case
str_output = str_var.make_variant(several_arbitrary_values) str_output = str_var.make_variant(*several_arbitrary_values)
str_var.validate_or_raise(str_output, "test-package") str_var.validate_or_raise(str_output, "test-package")
# equivalence each instance already validated # equivalence each instance already validated
assert str_output.value == wild_output.value assert str_output.value == wild_output.value
@ -760,21 +717,21 @@ def test_concretize_variant_default_with_multiple_defs(
"spec,variant_name,narrowed_type", "spec,variant_name,narrowed_type",
[ [
# dev_path is a special case # dev_path is a special case
("foo dev_path=/path/to/source", "dev_path", SingleValuedVariant), ("foo dev_path=/path/to/source", "dev_path", spack.variant.VariantType.SINGLE),
# reserved name: won't be touched # reserved name: won't be touched
("foo patches=2349dc44", "patches", VariantBase), ("foo patches=2349dc44", "patches", spack.variant.VariantType.MULTI),
# simple case -- one definition applies # simple case -- one definition applies
("variant-values@1.0 v=foo", "v", SingleValuedVariant), ("variant-values@1.0 v=foo", "v", spack.variant.VariantType.SINGLE),
# simple, but with bool valued variant # simple, but with bool valued variant
("pkg-a bvv=true", "bvv", BoolValuedVariant), ("pkg-a bvv=true", "bvv", spack.variant.VariantType.BOOL),
# takes the second definition, which overrides the single-valued one # takes the second definition, which overrides the single-valued one
("variant-values@2.0 v=bar", "v", MultiValuedVariant), ("variant-values@2.0 v=bar", "v", spack.variant.VariantType.MULTI),
], ],
) )
def test_substitute_abstract_variants_narrowing(mock_packages, spec, variant_name, narrowed_type): def test_substitute_abstract_variants_narrowing(mock_packages, spec, variant_name, narrowed_type):
spec = Spec(spec) spec = Spec(spec)
spack.spec.substitute_abstract_variants(spec) spack.spec.substitute_abstract_variants(spec)
assert type(spec.variants[variant_name]) is narrowed_type assert spec.variants[variant_name].type == narrowed_type
def test_substitute_abstract_variants_failure(mock_packages): def test_substitute_abstract_variants_failure(mock_packages):
@ -920,3 +877,12 @@ def test_patches_variant():
assert not Spec("patches:=abcdef").satisfies("patches:=ab") assert not Spec("patches:=abcdef").satisfies("patches:=ab")
assert not Spec("patches:=abcdef,xyz").satisfies("patches:=abc,xyz") assert not Spec("patches:=abcdef,xyz").satisfies("patches:=abc,xyz")
assert not Spec("patches:=abcdef").satisfies("patches:=abcdefghi") assert not Spec("patches:=abcdef").satisfies("patches:=abcdefghi")
def test_constrain_narrowing():
s = Spec("foo=*")
assert s.variants["foo"].type == spack.variant.VariantType.MULTI
assert not s.variants["foo"].concrete
s.constrain("+foo")
assert s.variants["foo"].type == spack.variant.VariantType.BOOL
assert s.variants["foo"].concrete

View File

@ -10,7 +10,6 @@
import functools import functools
import inspect import inspect
import itertools import itertools
import re
from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, Type, Union from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, Type, Union
import llnl.util.lang as lang import llnl.util.lang as lang
@ -33,24 +32,22 @@
"target", "target",
} }
special_variant_values = [None, "none", "*"]
class VariantType(enum.IntEnum):
class VariantType(enum.Enum):
"""Enum representing the three concrete variant types.""" """Enum representing the three concrete variant types."""
MULTI = "multi" BOOL = 1
BOOL = "bool" SINGLE = 2
SINGLE = "single" MULTI = 3
@property @property
def variant_class(self) -> Type: def string(self) -> str:
if self is self.MULTI: """Convert the variant type to a string."""
return MultiValuedVariant if self == VariantType.BOOL:
elif self is self.BOOL: return "bool"
return BoolValuedVariant elif self == VariantType.SINGLE:
else: return "single"
return SingleValuedVariant return "multi"
class Variant: class Variant:
@ -134,7 +131,7 @@ def isa_type(v):
self.sticky = sticky self.sticky = sticky
self.precedence = precedence self.precedence = precedence
def validate_or_raise(self, vspec: "VariantBase", pkg_name: str): def validate_or_raise(self, vspec: "VariantValue", pkg_name: str):
"""Validate a variant spec against this package variant. Raises an """Validate a variant spec against this package variant. Raises an
exception if any error is found. exception if any error is found.
@ -156,7 +153,7 @@ def validate_or_raise(self, vspec: "VariantBase", pkg_name: str):
raise InconsistentValidationError(vspec, self) raise InconsistentValidationError(vspec, self)
# If the value is exclusive there must be at most one # If the value is exclusive there must be at most one
value = vspec.value_as_tuple value = vspec.values
if not self.multi and len(value) != 1: if not self.multi and len(value) != 1:
raise MultipleValuesInExclusiveVariantError(vspec, pkg_name) raise MultipleValuesInExclusiveVariantError(vspec, pkg_name)
@ -191,27 +188,15 @@ def allowed_values(self):
v = docstring if docstring else "" v = docstring if docstring else ""
return v return v
def make_default(self): def make_default(self) -> "VariantValue":
"""Factory that creates a variant holding the default value. """Factory that creates a variant holding the default value(s)."""
variant = VariantValue.from_string_or_bool(self.name, self.default)
variant.type = self.variant_type
return variant
Returns: def make_variant(self, *value: Union[str, bool]) -> "VariantValue":
MultiValuedVariant or SingleValuedVariant or BoolValuedVariant: """Factory that creates a variant holding the value(s) passed."""
instance of the proper variant return VariantValue(self.variant_type, self.name, value)
"""
return self.make_variant(self.default)
def make_variant(self, value: Union[str, bool]) -> "VariantBase":
"""Factory that creates a variant holding the value passed as
a parameter.
Args:
value: value that will be hold by the variant
Returns:
MultiValuedVariant or SingleValuedVariant or BoolValuedVariant:
instance of the proper variant
"""
return self.variant_type.variant_class(self.name, value)
@property @property
def variant_type(self) -> VariantType: def variant_type(self) -> VariantType:
@ -254,116 +239,148 @@ def _flatten(values) -> Collection:
#: Type for value of a variant #: Type for value of a variant
ValueType = Union[str, bool, Tuple[Union[str, bool], ...]] ValueType = Tuple[Union[bool, str], ...]
#: Type of variant value when output for JSON, YAML, etc. #: Type of variant value when output for JSON, YAML, etc.
SerializedValueType = Union[str, bool, List[Union[str, bool]]] SerializedValueType = Union[str, bool, List[Union[str, bool]]]
@lang.lazy_lexicographic_ordering @lang.lazy_lexicographic_ordering
class VariantBase: class VariantValue:
"""A BaseVariant corresponds to a spec string of the form ``foo=bar`` or ``foo=bar,baz``. """A VariantValue is a key-value pair that represents a variant. It can have zero or more
It is a constraint on the spec and abstract in the sense that it must have **at least** these values. Values have set semantics, so they are unordered and unique. The variant type can
values -- concretization may add more values.""" be narrowed from multi to single to boolean, this limits the number of values that can be
stored in the variant. Multi-valued variants can either be concrete or abstract: abstract
means that the variant takes at least the values specified, but may take more when concretized.
Concrete means that the variant takes exactly the values specified. Lastly, a variant can be
marked as propagating, which means that it should be propagated to dependencies."""
name: str name: str
propagate: bool propagate: bool
_value: ValueType concrete: bool
_original_value: Any type: VariantType
_values: ValueType
def __init__(self, name: str, value: ValueType, propagate: bool = False) -> None: slots = ("name", "propagate", "concrete", "type", "_values")
def __init__(
self,
type: VariantType,
name: str,
value: ValueType,
*,
propagate: bool = False,
concrete: bool = False,
) -> None:
self.name = name self.name = name
self.type = type
self.propagate = propagate self.propagate = propagate
self.concrete = False # only multi-valued variants can be abstract
self.concrete = concrete or type in (VariantType.BOOL, VariantType.SINGLE)
# Invokes property setter # Invokes property setter
self.value = value self.set(*value)
@staticmethod @staticmethod
def from_node_dict( def from_node_dict(
name: str, value: Union[str, List[str]], *, propagate: bool = False, abstract: bool = False name: str, value: Union[str, List[str]], *, propagate: bool = False, abstract: bool = False
) -> "VariantBase": ) -> "VariantValue":
"""Reconstruct a variant from a node dict.""" """Reconstruct a variant from a node dict."""
if isinstance(value, list): if isinstance(value, list):
constructor = VariantBase if abstract else MultiValuedVariant return VariantValue(
mvar = constructor(name, (), propagate=propagate) VariantType.MULTI, name, tuple(value), propagate=propagate, concrete=not abstract
mvar._value = tuple(value) )
mvar._original_value = mvar._value
return mvar
# todo: is this necessary? not literal true / false in json/yaml?
elif str(value).upper() == "TRUE" or str(value).upper() == "FALSE": elif str(value).upper() == "TRUE" or str(value).upper() == "FALSE":
return BoolValuedVariant(name, value, propagate=propagate) return VariantValue(
VariantType.BOOL, name, (str(value).upper() == "TRUE",), propagate=propagate
)
return SingleValuedVariant(name, value, propagate=propagate) return VariantValue(VariantType.SINGLE, name, (value,), propagate=propagate)
@staticmethod
def from_string_or_bool(
name: str, value: Union[str, bool], *, propagate: bool = False, concrete: bool = False
) -> "VariantValue":
if value is True or value is False:
return VariantValue(VariantType.BOOL, name, (value,), propagate=propagate)
elif value.upper() in ("TRUE", "FALSE"):
return VariantValue(
VariantType.BOOL, name, (value.upper() == "TRUE",), propagate=propagate
)
elif value == "*":
return VariantValue(VariantType.MULTI, name, (), propagate=propagate)
return VariantValue(
VariantType.MULTI,
name,
tuple(value.split(",")),
propagate=propagate,
concrete=concrete,
)
@staticmethod
def from_concretizer(name: str, value: str, type: str) -> "VariantValue":
"""Reconstruct a variant from concretizer output."""
if type == "bool":
return VariantValue(VariantType.BOOL, name, (value == "True",))
elif type == "multi":
return VariantValue(VariantType.MULTI, name, (value,), concrete=True)
else:
return VariantValue(VariantType.SINGLE, name, (value,))
def yaml_entry(self) -> Tuple[str, SerializedValueType]: def yaml_entry(self) -> Tuple[str, SerializedValueType]:
"""Returns a key, value tuple suitable to be an entry in a yaml dict. """Returns a (key, value) tuple suitable to be an entry in a yaml dict.
Returns: Returns:
tuple: (name, value_representation) tuple: (name, value_representation)
""" """
return self.name, list(self.value_as_tuple) if self.type == VariantType.MULTI:
return self.name, list(self.values)
return self.name, self.values[0]
@property @property
def value_as_tuple(self) -> Tuple[Union[bool, str], ...]: def values(self) -> ValueType:
"""Getter for self.value that always returns a Tuple (even for single valued variants). return self._values
This makes it easy to iterate over possible values.
"""
if isinstance(self._value, (bool, str)):
return (self._value,)
return self._value
@property @property
def value(self) -> ValueType: def value(self) -> Union[ValueType, bool, str]:
"""Returns a tuple of strings containing the values stored in return self._values[0] if self.type != VariantType.MULTI else self._values
the variant.
Returns: def set(self, *value: Union[bool, str]) -> None:
tuple: values stored in the variant """Set the value(s) of the variant."""
""" if len(value) > 1:
return self._value value = tuple(sorted(set(value)))
@value.setter if self.type != VariantType.MULTI:
def value(self, value: ValueType) -> None: if len(value) != 1:
self._value_setter(value) raise MultipleValuesInExclusiveVariantError(self)
unwrapped = value[0]
if self.type == VariantType.BOOL and unwrapped not in (True, False):
raise ValueError(
f"cannot set a boolean variant to a value that is not a boolean: {unwrapped}"
)
def _value_setter(self, value: ValueType) -> None: if "*" in value:
# Store the original value raise InvalidVariantValueError("cannot use reserved value '*'")
self._original_value = value
if value == "*": self._values = value
self._value = ()
return
if not isinstance(value, (tuple, list)):
# Store a tuple of CSV string representations
# Tuple is necessary here instead of list because the
# values need to be hashed
value = tuple(re.split(r"\s*,\s*", str(value)))
for val in special_variant_values:
if val in value and len(value) > 1:
msg = "'%s' cannot be combined with other variant" % val
msg += " values."
raise InvalidVariantValueCombinationError(msg)
# With multi-value variants it is necessary
# to remove duplicates and give an order
# to a set
self._value = tuple(sorted(set(value)))
def _cmp_iter(self) -> Iterable: def _cmp_iter(self) -> Iterable:
yield self.name yield self.name
yield self.propagate yield self.propagate
yield from (str(v) for v in self.value_as_tuple) yield self.concrete
yield from (str(v) for v in self.values)
def copy(self) -> "VariantBase": def copy(self) -> "VariantValue":
variant = type(self)(self.name, self._original_value, self.propagate) return VariantValue(
variant.concrete = self.concrete self.type, self.name, self.values, propagate=self.propagate, concrete=self.concrete
return variant )
def satisfies(self, other: "VariantBase") -> bool: def satisfies(self, other: "VariantValue") -> bool:
"""The lhs satisfies the rhs if all possible concretizations of lhs are also """The lhs satisfies the rhs if all possible concretizations of lhs are also
possible concretizations of rhs.""" possible concretizations of rhs."""
if self.name != other.name: if self.name != other.name:
@ -376,139 +393,91 @@ def satisfies(self, other: "VariantBase") -> bool:
if self.name == "patches": if self.name == "patches":
return all( return all(
isinstance(v, str) isinstance(v, str)
and any(isinstance(w, str) and w.startswith(v) for w in self.value_as_tuple) and any(isinstance(w, str) and w.startswith(v) for w in self.values)
for v in other.value_as_tuple for v in other.values
) )
return all(v in self for v in other.value_as_tuple) return all(v in self for v in other.values)
if self.concrete: if self.concrete:
# both concrete: they must be equal # both concrete: they must be equal
return self.value_as_tuple == other.value_as_tuple return self.values == other.values
return False return False
def intersects(self, other: "VariantBase") -> bool: def intersects(self, other: "VariantValue") -> bool:
"""True iff there exists a concretization that satisfies both lhs and rhs.""" """True iff there exists a concretization that satisfies both lhs and rhs."""
if self.name != other.name: if self.name != other.name:
return False return False
if self.concrete: if self.concrete:
if other.concrete: if other.concrete:
return self.value_as_tuple == other.value_as_tuple return self.values == other.values
return all(v in self for v in other.value_as_tuple) return all(v in self for v in other.values)
if other.concrete: if other.concrete:
return all(v in other for v in self.value_as_tuple) return all(v in other for v in self.values)
# both abstract: the union is a valid concretization of both # both abstract: the union is a valid concretization of both
return True return True
def constrain(self, other: "VariantBase") -> bool: def constrain(self, other: "VariantValue") -> bool:
"""Constrain self with other if they intersect. Returns true iff self was changed.""" """Constrain self with other if they intersect. Returns true iff self was changed."""
if not self.intersects(other): if not self.intersects(other):
raise UnsatisfiableVariantSpecError(self, other) raise UnsatisfiableVariantSpecError(self, other)
old_value = self.value old_values = self.values
values = list(sorted({*self.value_as_tuple, *other.value_as_tuple})) self.set(*self.values, *other.values)
self._value_setter(",".join(str(v) for v in values)) changed = old_values != self.values
changed = old_value != self.value
if self.propagate and not other.propagate: if self.propagate and not other.propagate:
self.propagate = False self.propagate = False
changed = True changed = True
if not self.concrete and other.concrete: if not self.concrete and other.concrete:
self.concrete = True self.concrete = True
changed = True changed = True
if self.type > other.type:
self.type = other.type
changed = True
return changed return changed
def __contains__(self, item: Union[str, bool]) -> bool:
return item in self.value_as_tuple
def __repr__(self) -> str:
return f"{type(self).__name__}({repr(self.name)}, {repr(self._original_value)})"
def __str__(self) -> str:
concrete = ":" if self.concrete else ""
delim = "==" if self.propagate else "="
values_tuple = self.value_as_tuple
if values_tuple:
value_str = ",".join(str(v) for v in values_tuple)
else:
value_str = "*"
return f"{self.name}{concrete}{delim}{spack.spec_parser.quote_if_needed(value_str)}"
class MultiValuedVariant(VariantBase):
def __init__(self, name, value, propagate=False):
super().__init__(name, value, propagate)
self.concrete = True
def append(self, value: Union[str, bool]) -> None: def append(self, value: Union[str, bool]) -> None:
"""Add another value to this multi-valued variant.""" self.set(*self.values, value)
self._value = tuple(sorted((value,) + self.value_as_tuple))
self._original_value = ",".join(str(v) for v in self._value) def __contains__(self, item: Union[str, bool]) -> bool:
return item in self.values
def __str__(self) -> str:
# Special-case patches to not print the full 64 character sha256
if self.name == "patches":
values_str = ",".join(str(x)[:7] for x in self.value_as_tuple)
else:
values_str = ",".join(str(x) for x in self.value_as_tuple)
delim = "==" if self.propagate else "="
return f"{self.name}:{delim}{spack.spec_parser.quote_if_needed(values_str)}"
class SingleValuedVariant(VariantBase):
def __init__(self, name, value, propagate=False):
super().__init__(name, value, propagate)
self.concrete = True
def _value_setter(self, value: ValueType) -> None:
# Treat the value as a multi-valued variant
super()._value_setter(value)
# Then check if there's only a single value
values = self.value_as_tuple
if len(values) != 1:
raise MultipleValuesInExclusiveVariantError(self)
self._value = values[0]
def __contains__(self, item: ValueType) -> bool:
return item == self.value
def yaml_entry(self) -> Tuple[str, SerializedValueType]:
assert isinstance(self.value, (bool, str))
return self.name, self.value
def __str__(self) -> str:
delim = "==" if self.propagate else "="
return f"{self.name}{delim}{spack.spec_parser.quote_if_needed(str(self.value))}"
class BoolValuedVariant(SingleValuedVariant):
def __init__(self, name, value, propagate=False):
super().__init__(name, value, propagate)
self.concrete = True
def _value_setter(self, value: ValueType) -> None:
# Check the string representation of the value and turn
# it to a boolean
if str(value).upper() == "TRUE":
self._original_value = value
self._value = True
elif str(value).upper() == "FALSE":
self._original_value = value
self._value = False
else:
raise ValueError(
f'cannot construct a BoolValuedVariant for "{self.name}" from '
"a value that does not represent a bool"
)
def __contains__(self, item: ValueType) -> bool:
return item is self.value
def __str__(self) -> str: def __str__(self) -> str:
# boolean variants are printed +foo or ~foo
if self.type == VariantType.BOOL:
sigil = "+" if self.value else "~" sigil = "+" if self.value else "~"
if self.propagate: if self.propagate:
sigil *= 2 sigil *= 2
return f"{sigil}{self.name}" return f"{sigil}{self.name}"
# concrete multi-valued foo:=bar,baz
concrete = ":" if self.type == VariantType.MULTI and self.concrete else ""
delim = "==" if self.propagate else "="
if not self.values:
value_str = "*"
elif self.name == "patches" and self.concrete:
value_str = ",".join(str(x)[:7] for x in self.values)
else:
value_str = ",".join(str(x) for x in self.values)
return f"{self.name}{concrete}{delim}{spack.spec_parser.quote_if_needed(value_str)}"
def __repr__(self):
return (
f"VariantValue({self.type!r}, {self.name!r}, {self.values!r}, "
f"propagate={self.propagate!r}, concrete={self.concrete!r})"
)
def MultiValuedVariant(name: str, value: ValueType, propagate: bool = False) -> VariantValue:
return VariantValue(VariantType.MULTI, name, value, propagate=propagate, concrete=True)
def SingleValuedVariant(
name: str, value: Union[bool, str], propagate: bool = False
) -> VariantValue:
return VariantValue(VariantType.SINGLE, name, (value,), propagate=propagate)
def BoolValuedVariant(name: str, value: bool, propagate: bool = False) -> VariantValue:
return VariantValue(VariantType.BOOL, name, (value,), propagate=propagate)
# The class below inherit from Sequence to disguise as a tuple and comply # The class below inherit from Sequence to disguise as a tuple and comply
# with the semantic expected by the 'values' argument of the variant directive # with the semantic expected by the 'values' argument of the variant directive
@ -714,7 +683,7 @@ def __lt__(self, other):
def prevalidate_variant_value( def prevalidate_variant_value(
pkg_cls: "Type[spack.package_base.PackageBase]", pkg_cls: "Type[spack.package_base.PackageBase]",
variant: VariantBase, variant: VariantValue,
spec: Optional["spack.spec.Spec"] = None, spec: Optional["spack.spec.Spec"] = None,
strict: bool = False, strict: bool = False,
) -> List[Variant]: ) -> List[Variant]:
@ -735,8 +704,8 @@ def prevalidate_variant_value(
list of variant definitions that will accept the given value. List will be empty list of variant definitions that will accept the given value. List will be empty
only if the variant is a reserved variant. only if the variant is a reserved variant.
""" """
# don't validate wildcards or variants with reserved names # do not validate non-user variants or optional variants
if variant.value == ("*",) or variant.name in RESERVED_NAMES or variant.propagate: if variant.name in RESERVED_NAMES or variant.propagate:
return [] return []
# raise if there is no definition at all # raise if there is no definition at all
@ -819,17 +788,13 @@ class MultipleValuesInExclusiveVariantError(spack.error.SpecError, ValueError):
only one. only one.
""" """
def __init__(self, variant: VariantBase, pkg_name: Optional[str] = None): def __init__(self, variant: VariantValue, pkg_name: Optional[str] = None):
pkg_info = "" if pkg_name is None else f" in package '{pkg_name}'" pkg_info = "" if pkg_name is None else f" in package '{pkg_name}'"
msg = f"multiple values are not allowed for variant '{variant.name}'{pkg_info}" msg = f"multiple values are not allowed for variant '{variant.name}'{pkg_info}"
super().__init__(msg.format(variant, pkg_info)) super().__init__(msg.format(variant, pkg_info))
class InvalidVariantValueCombinationError(spack.error.SpecError):
"""Raised when a variant has values '*' or 'none' with other values."""
class InvalidVariantValueError(spack.error.SpecError): class InvalidVariantValueError(spack.error.SpecError):
"""Raised when variants have invalid values.""" """Raised when variants have invalid values."""