variants: fix narrowing multi -> single -> bool (#49880)

* `x=*` constrained by `+x` now produces a boolean valued variant instead of a multi-valued variant.
   
* Values are now always stored as a tuple internally, whether bool, single or multi-valued. 

* Value assignment has a stricter api to prevent ambiguity / type issues related to 
   `variant.value = "x"` / `variant.value = ["x"]` / `variant.value = ("x",)`. It's now `variant.set("x", ...)` for 
   single and multi-valued variants.

* The `_original_value` prop is dropped, since it was unused.

* The wildcard `*` is no longer a possible variant value in any type of variant, since the *parser*
   deals with it and creates a variant with no values.
This commit is contained in:
Harmen Stoppels 2025-04-16 09:44:38 +02:00 committed by GitHub
parent 1dc9bac745
commit 883bbf3826
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 293 additions and 359 deletions

View File

@ -1755,15 +1755,17 @@ def define_variant(
pkg_fact(fn.variant_condition(name, vid, cond_id))
# record type so we can construct the variant when we read it back in
self.gen.fact(fn.variant_type(vid, variant_def.variant_type.value))
self.gen.fact(fn.variant_type(vid, variant_def.variant_type.string))
if variant_def.sticky:
pkg_fact(fn.variant_sticky(vid))
# define defaults for this variant definition
defaults = variant_def.make_default().value if variant_def.multi else [variant_def.default]
for val in sorted(defaults):
pkg_fact(fn.variant_default_value_from_package_py(vid, val))
if variant_def.multi:
for val in sorted(variant_def.make_default().values):
pkg_fact(fn.variant_default_value_from_package_py(vid, val))
else:
pkg_fact(fn.variant_default_value_from_package_py(vid, variant_def.default))
# define possible values for this variant definition
values = variant_def.values
@ -1791,7 +1793,9 @@ def define_variant(
# make a spec indicating whether the variant has this conditional value
variant_has_value = spack.spec.Spec()
variant_has_value.variants[name] = vt.VariantBase(name, value.value)
variant_has_value.variants[name] = vt.VariantValue(
vt.VariantType.MULTI, name, (value.value,)
)
if value.when:
# the conditional value is always "possible", but it imposes its when condition as
@ -2373,7 +2377,7 @@ def preferred_variants(self, pkg_name):
)
continue
for value in variant.value_as_tuple:
for value in variant.values:
for variant_def in variant_defs:
self.variant_values_from_specs.add((pkg_name, id(variant_def), value))
self.gen.fact(
@ -2491,7 +2495,7 @@ def _spec_clauses(
if variant.value == ("*",):
continue
for value in variant.value_as_tuple:
for value in variant.values:
# ensure that the value *can* be valid for the spec
if spec.name and not spec.concrete and not spack.repo.PATH.is_virtual(spec.name):
variant_defs = vt.prevalidate_variant_value(
@ -3828,13 +3832,13 @@ def node_os(self, node, os):
def node_target(self, node, target):
self._arch(node).target = target
def variant_selected(self, node, name, value, variant_type, variant_id):
def variant_selected(self, node, name: str, value: str, variant_type: str, variant_id):
spec = self._specs[node]
variant = spec.variants.get(name)
if not variant:
spec.variants[name] = vt.VariantType(variant_type).variant_class(name, value)
spec.variants[name] = vt.VariantValue.from_concretizer(name, value, variant_type)
else:
assert variant_type == vt.VariantType.MULTI.value, (
assert variant_type == "multi", (
f"Can't have multiple values for single-valued variant: "
f"{node}, {name}, {value}, {variant_type}, {variant_id}"
)
@ -4213,10 +4217,10 @@ def _inject_patches_variant(root: spack.spec.Spec) -> None:
continue
patches = list(spec_to_patches[id(spec)])
variant: vt.MultiValuedVariant = spec.variants.setdefault(
variant: vt.VariantValue = spec.variants.setdefault(
"patches", vt.MultiValuedVariant("patches", ())
)
variant.value = tuple(p.sha256 for p in patches)
variant.set(*(p.sha256 for p in patches))
# FIXME: Monkey patches variant to store patches order
ordered_hashes = [(*p.ordering_key, p.sha256) for p in patches if p.ordering_key]
ordered_hashes.sort()

View File

@ -1698,7 +1698,9 @@ def _dependencies_dict(self, depflag: dt.DepFlag = dt.ALL):
result[key] = list(group)
return result
def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> None:
def _add_flag(
self, name: str, value: Union[str, bool], propagate: bool, concrete: bool
) -> None:
"""Called by the parser to add a known flag"""
if propagate and name in vt.RESERVED_NAMES:
@ -1708,6 +1710,7 @@ def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> N
valid_flags = FlagMap.valid_compiler_flags()
if name == "arch" or name == "architecture":
assert type(value) is str, "architecture have a string value"
parts = tuple(value.split("-"))
plat, os, tgt = parts if len(parts) == 3 else (None, None, value)
self._set_architecture(platform=plat, os=os, target=tgt)
@ -1721,17 +1724,15 @@ def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> N
self.namespace = value
elif name in valid_flags:
assert self.compiler_flags is not None
assert type(value) is str, f"{name} must have a string value"
flags_and_propagation = spack.compilers.flags.tokenize_flags(value, propagate)
flag_group = " ".join(x for (x, y) in flags_and_propagation)
for flag, propagation in flags_and_propagation:
self.compiler_flags.add_flag(name, flag, propagation, flag_group)
else:
if str(value).upper() == "TRUE" or str(value).upper() == "FALSE":
self.variants[name] = vt.BoolValuedVariant(name, value, propagate)
elif concrete:
self.variants[name] = vt.MultiValuedVariant(name, value, propagate)
else:
self.variants[name] = vt.VariantBase(name, value, propagate)
self.variants[name] = vt.VariantValue.from_string_or_bool(
name, value, propagate=propagate, concrete=concrete
)
def _set_architecture(self, **kwargs):
"""Called by the parser to set the architecture."""
@ -4481,7 +4482,7 @@ def __init__(self, spec: Spec):
def __setitem__(self, name, vspec):
# Raise a TypeError if vspec is not of the right type
if not isinstance(vspec, vt.VariantBase):
if not isinstance(vspec, vt.VariantValue):
raise TypeError(
"VariantMap accepts only values of variant types "
f"[got {type(vspec).__name__} instead]"
@ -4621,7 +4622,7 @@ def __str__(self):
bool_keys = []
kv_keys = []
for key in sorted_keys:
if isinstance(self[key].value, bool):
if self[key].type == vt.VariantType.BOOL:
bool_keys.append(key)
else:
kv_keys.append(key)
@ -4654,7 +4655,8 @@ def substitute_abstract_variants(spec: Spec):
unknown = []
for name, v in spec.variants.items():
if name == "dev_path":
spec.variants.substitute(vt.SingleValuedVariant(name, v._original_value))
v.type = vt.VariantType.SINGLE
v.concrete = True
continue
elif name in vt.RESERVED_NAMES:
continue
@ -4677,7 +4679,7 @@ def substitute_abstract_variants(spec: Spec):
if rest:
continue
new_variant = pkg_variant.make_variant(v._original_value)
new_variant = pkg_variant.make_variant(*v.values)
pkg_variant.validate_or_raise(new_variant, spec.name)
spec.variants.substitute(new_variant)
@ -4803,7 +4805,7 @@ def from_node_dict(cls, node):
for val in values:
spec.compiler_flags.add_flag(name, val, propagate)
else:
spec.variants[name] = vt.MultiValuedVariant.from_node_dict(
spec.variants[name] = vt.VariantValue.from_node_dict(
name, values, propagate=propagate, abstract=name in abstract_variants
)
@ -4829,7 +4831,7 @@ def from_node_dict(cls, node):
patches = node["patches"]
if len(patches) > 0:
mvar = spec.variants.setdefault("patches", vt.MultiValuedVariant("patches", ()))
mvar.value = patches
mvar.set(*patches)
# FIXME: Monkey patches mvar to store patches order
mvar._patches_in_order_of_appearance = patches

View File

@ -62,7 +62,7 @@
import sys
import traceback
import warnings
from typing import Iterator, List, Optional, Tuple
from typing import Iterator, List, Optional, Tuple, Union
from llnl.util.tty import color
@ -369,7 +369,7 @@ def raise_parsing_error(string: str, cause: Optional[Exception] = None):
"""Raise a spec parsing error with token context."""
raise SpecParsingError(string, self.ctx.current_token, self.literal_str) from cause
def add_flag(name: str, value: str, propagate: bool, concrete: bool):
def add_flag(name: str, value: Union[str, bool], propagate: bool, concrete: bool):
"""Wrapper around ``Spec._add_flag()`` that adds parser context to errors raised."""
try:
initial_spec._add_flag(name, value, propagate, concrete)

View File

@ -973,13 +973,10 @@ def test_spec_formatting_bad_formats(self, default_mock_concretization, fmt_str)
with pytest.raises(SpecFormatStringError):
spec.format(fmt_str)
def test_combination_of_wildcard_or_none(self):
# Test that using 'none' and another value raises
with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot be combined"):
Spec("multivalue-variant foo=none,bar")
# Test that using wildcard and another value raises
with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot be combined"):
def test_wildcard_is_invalid_variant_value(self):
"""The spec string x=* is parsed as a multi-valued variant with values the empty set.
That excludes * as a literal variant value."""
with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot use reserved value"):
Spec("multivalue-variant foo=*,bar")
def test_errors_in_variant_directive(self):

View File

@ -21,7 +21,7 @@
SingleValuedVariant,
UnsatisfiableVariantSpecError,
Variant,
VariantBase,
VariantValue,
disjoint_sets,
)
@ -29,53 +29,37 @@
class TestMultiValuedVariant:
def test_initialization(self):
# Basic properties
a = MultiValuedVariant("foo", "bar,baz")
assert repr(a) == "MultiValuedVariant('foo', 'bar,baz')"
a = MultiValuedVariant("foo", ("bar", "baz"))
assert str(a) == "foo:=bar,baz"
assert a.values == ("bar", "baz")
assert a.value == ("bar", "baz")
assert "bar" in a
assert "baz" in a
assert eval(repr(a)) == a
# Spaces are trimmed
b = MultiValuedVariant("foo", "bar, baz")
assert repr(b) == "MultiValuedVariant('foo', 'bar, baz')"
assert str(b) == "foo:=bar,baz"
assert b.value == ("bar", "baz")
assert "bar" in b
assert "baz" in b
assert a == b
assert hash(a) == hash(b)
assert eval(repr(b)) == a
# Order is not important
c = MultiValuedVariant("foo", "baz, bar")
assert repr(c) == "MultiValuedVariant('foo', 'baz, bar')"
c = MultiValuedVariant("foo", ("baz", "bar"))
assert str(c) == "foo:=bar,baz"
assert c.value == ("bar", "baz")
assert c.values == ("bar", "baz")
assert "bar" in c
assert "baz" in c
assert a == c
assert hash(a) == hash(c)
assert eval(repr(c)) == a
# Check the copy
d = a.copy()
assert repr(a) == repr(d)
assert str(a) == str(d)
assert d.value == ("bar", "baz")
assert d.values == ("bar", "baz")
assert "bar" in d
assert "baz" in d
assert a == d
assert a is not d
assert hash(a) == hash(d)
assert eval(repr(d)) == a
def test_satisfies(self):
a = MultiValuedVariant("foo", "bar,baz")
b = MultiValuedVariant("foo", "bar")
c = MultiValuedVariant("fee", "bar,baz")
d = MultiValuedVariant("foo", "True")
a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("foo", ("bar",))
c = MultiValuedVariant("fee", ("bar", "baz"))
d = MultiValuedVariant("foo", (True,))
# concrete, different values do not satisfy each other
assert not a.satisfies(b) and not b.satisfies(a)
@ -85,21 +69,19 @@ def test_satisfies(self):
# eachother
b_sv = SingleValuedVariant("foo", "bar")
assert b.satisfies(b_sv) and b_sv.satisfies(b)
d_sv = SingleValuedVariant("foo", "True")
d_sv = SingleValuedVariant("foo", True)
assert d.satisfies(d_sv) and d_sv.satisfies(d)
almost_d_bv = SingleValuedVariant("foo", "true")
assert not d.satisfies(almost_d_bv)
almost_d_bv = SingleValuedVariant("foo", True)
assert d.satisfies(almost_d_bv)
# BoolValuedVariant actually stores the value as a boolean, whereas with MV and SV the
# value is string "True".
d_bv = BoolValuedVariant("foo", "True")
assert not d.satisfies(d_bv) and not d_bv.satisfies(d)
d_bv = BoolValuedVariant("foo", True)
assert d.satisfies(d_bv) and d_bv.satisfies(d)
def test_intersects(self):
a = MultiValuedVariant("foo", "bar,baz")
b = MultiValuedVariant("foo", "True")
c = MultiValuedVariant("fee", "bar,baz")
d = MultiValuedVariant("foo", "bar,barbaz")
a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("foo", (True,))
c = MultiValuedVariant("fee", ("bar", "baz"))
d = MultiValuedVariant("foo", ("bar", "barbaz"))
# concrete, different values do not intersect.
assert not a.intersects(b) and not b.intersects(a)
@ -110,47 +92,45 @@ def test_intersects(self):
assert not c.intersects(d) and not d.intersects(c)
# SV and MV intersect if they have the same concrete value.
b_sv = SingleValuedVariant("foo", "True")
b_sv = SingleValuedVariant("foo", True)
assert b.intersects(b_sv)
assert not c.intersects(b_sv)
# BoolValuedVariant stores a bool, which is not the same as the string "True" in MV.
b_bv = BoolValuedVariant("foo", "True")
assert not b.intersects(b_bv)
# BoolValuedVariant intersects if the value is the same
b_bv = BoolValuedVariant("foo", True)
assert b.intersects(b_bv)
assert not c.intersects(b_bv)
def test_constrain(self):
# Concrete values cannot be constrained
a = MultiValuedVariant("foo", "bar,baz")
b = MultiValuedVariant("foo", "bar")
a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("foo", ("bar",))
with pytest.raises(UnsatisfiableVariantSpecError):
a.constrain(b)
with pytest.raises(UnsatisfiableVariantSpecError):
b.constrain(a)
# Try to constrain on the same value
a = MultiValuedVariant("foo", "bar,baz")
a = MultiValuedVariant("foo", ("bar", "baz"))
b = a.copy()
assert not a.constrain(b)
assert a == b == MultiValuedVariant("foo", "bar,baz")
assert a == b == MultiValuedVariant("foo", ("bar", "baz"))
# Try to constrain on a different name
a = MultiValuedVariant("foo", "bar,baz")
b = MultiValuedVariant("fee", "bar")
a = MultiValuedVariant("foo", ("bar", "baz"))
b = MultiValuedVariant("fee", ("bar",))
with pytest.raises(UnsatisfiableVariantSpecError):
a.constrain(b)
def test_yaml_entry(self):
a = MultiValuedVariant("foo", "bar,baz,barbaz")
b = MultiValuedVariant("foo", "bar, baz, barbaz")
expected = ("foo", sorted(["bar", "baz", "barbaz"]))
a = MultiValuedVariant("foo", ("bar", "baz", "barbaz"))
expected = ("foo", sorted(("bar", "baz", "barbaz")))
assert a.yaml_entry() == expected
assert b.yaml_entry() == expected
a = MultiValuedVariant("foo", "bar")
a = MultiValuedVariant("foo", ("bar",))
expected = ("foo", sorted(["bar"]))
assert a.yaml_entry() == expected
@ -160,26 +140,20 @@ class TestSingleValuedVariant:
def test_initialization(self):
# Basic properties
a = SingleValuedVariant("foo", "bar")
assert repr(a) == "SingleValuedVariant('foo', 'bar')"
assert str(a) == "foo=bar"
assert a.values == ("bar",)
assert a.value == "bar"
assert "bar" in a
assert eval(repr(a)) == a
# Raise if multiple values are passed
with pytest.raises(ValueError):
SingleValuedVariant("foo", "bar, baz")
# Check the copy
b = a.copy()
assert repr(a) == repr(b)
assert str(a) == str(b)
assert b.values == ("bar",)
assert b.value == "bar"
assert "bar" in b
assert a == b
assert a is not b
assert hash(a) == hash(b)
assert eval(repr(b)) == a
def test_satisfies(self):
a = SingleValuedVariant("foo", "bar")
@ -247,54 +221,37 @@ def test_yaml_entry(self):
class TestBoolValuedVariant:
def test_initialization(self):
# Basic properties - True value
for v in (True, "True", "TRUE", "TrUe"):
a = BoolValuedVariant("foo", v)
assert repr(a) == "BoolValuedVariant('foo', {0})".format(repr(v))
assert str(a) == "+foo"
assert a.value is True
assert True in a
assert eval(repr(a)) == a
a = BoolValuedVariant("foo", True)
assert str(a) == "+foo"
assert a.value is True
assert a.values == (True,)
assert True in a
# Copy - True value
b = a.copy()
assert repr(a) == repr(b)
assert str(a) == str(b)
assert b.value is True
assert b.values == (True,)
assert True in b
assert a == b
assert a is not b
assert hash(a) == hash(b)
assert eval(repr(b)) == a
# Basic properties - False value
for v in (False, "False", "FALSE", "FaLsE"):
a = BoolValuedVariant("foo", v)
assert repr(a) == "BoolValuedVariant('foo', {0})".format(repr(v))
assert str(a) == "~foo"
assert a.value is False
assert False in a
assert eval(repr(a)) == a
# Copy - True value
# Copy - False value
a = BoolValuedVariant("foo", False)
b = a.copy()
assert repr(a) == repr(b)
assert str(a) == str(b)
assert b.value is False
assert b.values == (False,)
assert False in b
assert a == b
assert a is not b
assert eval(repr(b)) == a
# Invalid values
for v in ("bar", "bar,baz"):
with pytest.raises(ValueError):
BoolValuedVariant("foo", v)
def test_satisfies(self):
a = BoolValuedVariant("foo", True)
b = BoolValuedVariant("foo", False)
c = BoolValuedVariant("fee", False)
d = BoolValuedVariant("foo", "True")
d = BoolValuedVariant("foo", True)
# concrete, different values do not satisfy each other
assert not a.satisfies(b) and not b.satisfies(a)
@ -325,7 +282,7 @@ def test_intersects(self):
a = BoolValuedVariant("foo", True)
b = BoolValuedVariant("fee", True)
c = BoolValuedVariant("foo", False)
d = BoolValuedVariant("foo", "True")
d = BoolValuedVariant("foo", True)
# concrete, different values do not intersect each other
assert not a.intersects(b) and not b.intersects(a)
@ -347,7 +304,7 @@ def test_intersects(self):
def test_constrain(self):
# Try to constrain on a value equal to self
a = BoolValuedVariant("foo", "True")
a = BoolValuedVariant("foo", True)
b = BoolValuedVariant("foo", True)
assert not a.constrain(b)
@ -375,24 +332,24 @@ def test_constrain(self):
assert a == BoolValuedVariant("foo", True)
def test_yaml_entry(self):
a = BoolValuedVariant("foo", "True")
a = BoolValuedVariant("foo", True)
expected = ("foo", True)
assert a.yaml_entry() == expected
a = BoolValuedVariant("foo", "False")
a = BoolValuedVariant("foo", False)
expected = ("foo", False)
assert a.yaml_entry() == expected
def test_from_node_dict():
a = MultiValuedVariant.from_node_dict("foo", ["bar"])
assert type(a) is MultiValuedVariant
a = VariantValue.from_node_dict("foo", ["bar"])
assert a.type == spack.variant.VariantType.MULTI
a = MultiValuedVariant.from_node_dict("foo", "bar")
assert type(a) is SingleValuedVariant
a = VariantValue.from_node_dict("foo", "bar")
assert a.type == spack.variant.VariantType.SINGLE
a = MultiValuedVariant.from_node_dict("foo", "true")
assert type(a) is BoolValuedVariant
a = VariantValue.from_node_dict("foo", "true")
assert a.type == spack.variant.VariantType.BOOL
class TestVariant:
@ -406,7 +363,7 @@ def test_validation(self):
# Multiple values are not allowed
with pytest.raises(MultipleValuesInExclusiveVariantError):
vspec.value = "bar,baz"
vspec.set("bar", "baz")
# Inconsistent vspec
vspec.name = "FOO"
@ -415,10 +372,10 @@ def test_validation(self):
# Valid multi-value vspec
a.multi = True
vspec = a.make_variant("bar,baz")
vspec = a.make_variant("bar", "baz")
a.validate_or_raise(vspec, "test-package")
# Add an invalid value
vspec.value = "bar,baz,barbaz"
vspec.set("bar", "baz", "barbaz")
with pytest.raises(InvalidVariantValueError):
a.validate_or_raise(vspec, "test-package")
@ -429,12 +386,12 @@ def validator(x):
except ValueError:
return False
a = Variant("foo", default=1024, description="", values=validator, multi=False)
a = Variant("foo", default="1024", description="", values=validator, multi=False)
vspec = a.make_default()
a.validate_or_raise(vspec, "test-package")
vspec.value = 2056
vspec.set("2056")
a.validate_or_raise(vspec, "test-package")
vspec.value = "foo"
vspec.set("foo")
with pytest.raises(InvalidVariantValueError):
a.validate_or_raise(vspec, "test-package")
@ -464,9 +421,9 @@ def test_invalid_values(self) -> None:
a["foo"] = 2
# Duplicate variant
a["foo"] = MultiValuedVariant("foo", "bar,baz")
a["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
with pytest.raises(DuplicateVariantError):
a["foo"] = MultiValuedVariant("foo", "bar")
a["foo"] = MultiValuedVariant("foo", ("bar",))
with pytest.raises(DuplicateVariantError):
a["foo"] = SingleValuedVariant("foo", "bar")
@ -476,7 +433,7 @@ def test_invalid_values(self) -> None:
# Non matching names between key and vspec.name
with pytest.raises(KeyError):
a["bar"] = MultiValuedVariant("foo", "bar")
a["bar"] = MultiValuedVariant("foo", ("bar",))
def test_set_item(self) -> None:
# Check that all the three types of variants are accepted
@ -484,7 +441,7 @@ def test_set_item(self) -> None:
a["foo"] = BoolValuedVariant("foo", True)
a["bar"] = SingleValuedVariant("bar", "baz")
a["foobar"] = MultiValuedVariant("foobar", "a, b, c, d, e")
a["foobar"] = MultiValuedVariant("foobar", ("a", "b", "c", "d", "e"))
def test_substitute(self) -> None:
# Check substitution of a key that exists
@ -500,13 +457,13 @@ def test_substitute(self) -> None:
def test_satisfies_and_constrain(self) -> None:
# foo=bar foobar=fee feebar=foo
a = VariantMap(Spec())
a["foo"] = MultiValuedVariant("foo", "bar")
a["foo"] = MultiValuedVariant("foo", ("bar",))
a["foobar"] = SingleValuedVariant("foobar", "fee")
a["feebar"] = SingleValuedVariant("feebar", "foo")
# foo=bar,baz foobar=fee shared=True
b = VariantMap(Spec())
b["foo"] = MultiValuedVariant("foo", "bar, baz")
b["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
b["foobar"] = SingleValuedVariant("foobar", "fee")
b["shared"] = BoolValuedVariant("shared", True)
@ -516,7 +473,7 @@ def test_satisfies_and_constrain(self) -> None:
# foo=bar,baz foobar=fee feebar=foo shared=True
c = VariantMap(Spec())
c["foo"] = MultiValuedVariant("foo", "bar, baz")
c["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
c["foobar"] = SingleValuedVariant("foobar", "fee")
c["feebar"] = SingleValuedVariant("feebar", "foo")
c["shared"] = BoolValuedVariant("shared", True)
@ -529,14 +486,14 @@ def test_copy(self) -> None:
a = VariantMap(Spec())
a["foo"] = BoolValuedVariant("foo", True)
a["bar"] = SingleValuedVariant("bar", "baz")
a["foobar"] = MultiValuedVariant("foobar", "a, b, c, d, e")
a["foobar"] = MultiValuedVariant("foobar", ("a", "b", "c", "d", "e"))
c = a.copy()
assert a == c
def test_str(self) -> None:
c = VariantMap(Spec())
c["foo"] = MultiValuedVariant("foo", "bar, baz")
c["foo"] = MultiValuedVariant("foo", ("bar", "baz"))
c["foobar"] = SingleValuedVariant("foobar", "fee")
c["feebar"] = SingleValuedVariant("feebar", "foo")
c["shared"] = BoolValuedVariant("shared", True)
@ -635,10 +592,10 @@ def test_wild_card_valued_variants_equivalent_to_str():
several_arbitrary_values = ("doe", "re", "mi")
# "*" case
wild_output = wild_var.make_variant(several_arbitrary_values)
wild_output = wild_var.make_variant(*several_arbitrary_values)
wild_var.validate_or_raise(wild_output, "test-package")
# str case
str_output = str_var.make_variant(several_arbitrary_values)
str_output = str_var.make_variant(*several_arbitrary_values)
str_var.validate_or_raise(str_output, "test-package")
# equivalence each instance already validated
assert str_output.value == wild_output.value
@ -760,21 +717,21 @@ def test_concretize_variant_default_with_multiple_defs(
"spec,variant_name,narrowed_type",
[
# dev_path is a special case
("foo dev_path=/path/to/source", "dev_path", SingleValuedVariant),
("foo dev_path=/path/to/source", "dev_path", spack.variant.VariantType.SINGLE),
# reserved name: won't be touched
("foo patches=2349dc44", "patches", VariantBase),
("foo patches=2349dc44", "patches", spack.variant.VariantType.MULTI),
# simple case -- one definition applies
("variant-values@1.0 v=foo", "v", SingleValuedVariant),
("variant-values@1.0 v=foo", "v", spack.variant.VariantType.SINGLE),
# simple, but with bool valued variant
("pkg-a bvv=true", "bvv", BoolValuedVariant),
("pkg-a bvv=true", "bvv", spack.variant.VariantType.BOOL),
# takes the second definition, which overrides the single-valued one
("variant-values@2.0 v=bar", "v", MultiValuedVariant),
("variant-values@2.0 v=bar", "v", spack.variant.VariantType.MULTI),
],
)
def test_substitute_abstract_variants_narrowing(mock_packages, spec, variant_name, narrowed_type):
spec = Spec(spec)
spack.spec.substitute_abstract_variants(spec)
assert type(spec.variants[variant_name]) is narrowed_type
assert spec.variants[variant_name].type == narrowed_type
def test_substitute_abstract_variants_failure(mock_packages):
@ -920,3 +877,12 @@ def test_patches_variant():
assert not Spec("patches:=abcdef").satisfies("patches:=ab")
assert not Spec("patches:=abcdef,xyz").satisfies("patches:=abc,xyz")
assert not Spec("patches:=abcdef").satisfies("patches:=abcdefghi")
def test_constrain_narrowing():
s = Spec("foo=*")
assert s.variants["foo"].type == spack.variant.VariantType.MULTI
assert not s.variants["foo"].concrete
s.constrain("+foo")
assert s.variants["foo"].type == spack.variant.VariantType.BOOL
assert s.variants["foo"].concrete

View File

@ -10,7 +10,6 @@
import functools
import inspect
import itertools
import re
from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, Type, Union
import llnl.util.lang as lang
@ -33,24 +32,22 @@
"target",
}
special_variant_values = [None, "none", "*"]
class VariantType(enum.Enum):
class VariantType(enum.IntEnum):
"""Enum representing the three concrete variant types."""
MULTI = "multi"
BOOL = "bool"
SINGLE = "single"
BOOL = 1
SINGLE = 2
MULTI = 3
@property
def variant_class(self) -> Type:
if self is self.MULTI:
return MultiValuedVariant
elif self is self.BOOL:
return BoolValuedVariant
else:
return SingleValuedVariant
def string(self) -> str:
"""Convert the variant type to a string."""
if self == VariantType.BOOL:
return "bool"
elif self == VariantType.SINGLE:
return "single"
return "multi"
class Variant:
@ -134,7 +131,7 @@ def isa_type(v):
self.sticky = sticky
self.precedence = precedence
def validate_or_raise(self, vspec: "VariantBase", pkg_name: str):
def validate_or_raise(self, vspec: "VariantValue", pkg_name: str):
"""Validate a variant spec against this package variant. Raises an
exception if any error is found.
@ -156,7 +153,7 @@ def validate_or_raise(self, vspec: "VariantBase", pkg_name: str):
raise InconsistentValidationError(vspec, self)
# If the value is exclusive there must be at most one
value = vspec.value_as_tuple
value = vspec.values
if not self.multi and len(value) != 1:
raise MultipleValuesInExclusiveVariantError(vspec, pkg_name)
@ -191,27 +188,15 @@ def allowed_values(self):
v = docstring if docstring else ""
return v
def make_default(self):
"""Factory that creates a variant holding the default value.
def make_default(self) -> "VariantValue":
"""Factory that creates a variant holding the default value(s)."""
variant = VariantValue.from_string_or_bool(self.name, self.default)
variant.type = self.variant_type
return variant
Returns:
MultiValuedVariant or SingleValuedVariant or BoolValuedVariant:
instance of the proper variant
"""
return self.make_variant(self.default)
def make_variant(self, value: Union[str, bool]) -> "VariantBase":
"""Factory that creates a variant holding the value passed as
a parameter.
Args:
value: value that will be hold by the variant
Returns:
MultiValuedVariant or SingleValuedVariant or BoolValuedVariant:
instance of the proper variant
"""
return self.variant_type.variant_class(self.name, value)
def make_variant(self, *value: Union[str, bool]) -> "VariantValue":
"""Factory that creates a variant holding the value(s) passed."""
return VariantValue(self.variant_type, self.name, value)
@property
def variant_type(self) -> VariantType:
@ -254,116 +239,148 @@ def _flatten(values) -> Collection:
#: Type for value of a variant
ValueType = Union[str, bool, Tuple[Union[str, bool], ...]]
ValueType = Tuple[Union[bool, str], ...]
#: Type of variant value when output for JSON, YAML, etc.
SerializedValueType = Union[str, bool, List[Union[str, bool]]]
@lang.lazy_lexicographic_ordering
class VariantBase:
"""A BaseVariant corresponds to a spec string of the form ``foo=bar`` or ``foo=bar,baz``.
It is a constraint on the spec and abstract in the sense that it must have **at least** these
values -- concretization may add more values."""
class VariantValue:
"""A VariantValue is a key-value pair that represents a variant. It can have zero or more
values. Values have set semantics, so they are unordered and unique. The variant type can
be narrowed from multi to single to boolean, this limits the number of values that can be
stored in the variant. Multi-valued variants can either be concrete or abstract: abstract
means that the variant takes at least the values specified, but may take more when concretized.
Concrete means that the variant takes exactly the values specified. Lastly, a variant can be
marked as propagating, which means that it should be propagated to dependencies."""
name: str
propagate: bool
_value: ValueType
_original_value: Any
concrete: bool
type: VariantType
_values: ValueType
def __init__(self, name: str, value: ValueType, propagate: bool = False) -> None:
slots = ("name", "propagate", "concrete", "type", "_values")
def __init__(
self,
type: VariantType,
name: str,
value: ValueType,
*,
propagate: bool = False,
concrete: bool = False,
) -> None:
self.name = name
self.type = type
self.propagate = propagate
self.concrete = False
# only multi-valued variants can be abstract
self.concrete = concrete or type in (VariantType.BOOL, VariantType.SINGLE)
# Invokes property setter
self.value = value
self.set(*value)
@staticmethod
def from_node_dict(
name: str, value: Union[str, List[str]], *, propagate: bool = False, abstract: bool = False
) -> "VariantBase":
) -> "VariantValue":
"""Reconstruct a variant from a node dict."""
if isinstance(value, list):
constructor = VariantBase if abstract else MultiValuedVariant
mvar = constructor(name, (), propagate=propagate)
mvar._value = tuple(value)
mvar._original_value = mvar._value
return mvar
return VariantValue(
VariantType.MULTI, name, tuple(value), propagate=propagate, concrete=not abstract
)
# todo: is this necessary? not literal true / false in json/yaml?
elif str(value).upper() == "TRUE" or str(value).upper() == "FALSE":
return BoolValuedVariant(name, value, propagate=propagate)
return VariantValue(
VariantType.BOOL, name, (str(value).upper() == "TRUE",), propagate=propagate
)
return SingleValuedVariant(name, value, propagate=propagate)
return VariantValue(VariantType.SINGLE, name, (value,), propagate=propagate)
@staticmethod
def from_string_or_bool(
name: str, value: Union[str, bool], *, propagate: bool = False, concrete: bool = False
) -> "VariantValue":
if value is True or value is False:
return VariantValue(VariantType.BOOL, name, (value,), propagate=propagate)
elif value.upper() in ("TRUE", "FALSE"):
return VariantValue(
VariantType.BOOL, name, (value.upper() == "TRUE",), propagate=propagate
)
elif value == "*":
return VariantValue(VariantType.MULTI, name, (), propagate=propagate)
return VariantValue(
VariantType.MULTI,
name,
tuple(value.split(",")),
propagate=propagate,
concrete=concrete,
)
@staticmethod
def from_concretizer(name: str, value: str, type: str) -> "VariantValue":
"""Reconstruct a variant from concretizer output."""
if type == "bool":
return VariantValue(VariantType.BOOL, name, (value == "True",))
elif type == "multi":
return VariantValue(VariantType.MULTI, name, (value,), concrete=True)
else:
return VariantValue(VariantType.SINGLE, name, (value,))
def yaml_entry(self) -> Tuple[str, SerializedValueType]:
"""Returns a key, value tuple suitable to be an entry in a yaml dict.
"""Returns a (key, value) tuple suitable to be an entry in a yaml dict.
Returns:
tuple: (name, value_representation)
"""
return self.name, list(self.value_as_tuple)
if self.type == VariantType.MULTI:
return self.name, list(self.values)
return self.name, self.values[0]
@property
def value_as_tuple(self) -> Tuple[Union[bool, str], ...]:
"""Getter for self.value that always returns a Tuple (even for single valued variants).
This makes it easy to iterate over possible values.
"""
if isinstance(self._value, (bool, str)):
return (self._value,)
return self._value
def values(self) -> ValueType:
return self._values
@property
def value(self) -> ValueType:
"""Returns a tuple of strings containing the values stored in
the variant.
def value(self) -> Union[ValueType, bool, str]:
return self._values[0] if self.type != VariantType.MULTI else self._values
Returns:
tuple: values stored in the variant
"""
return self._value
def set(self, *value: Union[bool, str]) -> None:
"""Set the value(s) of the variant."""
if len(value) > 1:
value = tuple(sorted(set(value)))
@value.setter
def value(self, value: ValueType) -> None:
self._value_setter(value)
if self.type != VariantType.MULTI:
if len(value) != 1:
raise MultipleValuesInExclusiveVariantError(self)
unwrapped = value[0]
if self.type == VariantType.BOOL and unwrapped not in (True, False):
raise ValueError(
f"cannot set a boolean variant to a value that is not a boolean: {unwrapped}"
)
def _value_setter(self, value: ValueType) -> None:
# Store the original value
self._original_value = value
if "*" in value:
raise InvalidVariantValueError("cannot use reserved value '*'")
if value == "*":
self._value = ()
return
if not isinstance(value, (tuple, list)):
# Store a tuple of CSV string representations
# Tuple is necessary here instead of list because the
# values need to be hashed
value = tuple(re.split(r"\s*,\s*", str(value)))
for val in special_variant_values:
if val in value and len(value) > 1:
msg = "'%s' cannot be combined with other variant" % val
msg += " values."
raise InvalidVariantValueCombinationError(msg)
# With multi-value variants it is necessary
# to remove duplicates and give an order
# to a set
self._value = tuple(sorted(set(value)))
self._values = value
def _cmp_iter(self) -> Iterable:
yield self.name
yield self.propagate
yield from (str(v) for v in self.value_as_tuple)
yield self.concrete
yield from (str(v) for v in self.values)
def copy(self) -> "VariantBase":
variant = type(self)(self.name, self._original_value, self.propagate)
variant.concrete = self.concrete
return variant
def copy(self) -> "VariantValue":
return VariantValue(
self.type, self.name, self.values, propagate=self.propagate, concrete=self.concrete
)
def satisfies(self, other: "VariantBase") -> bool:
def satisfies(self, other: "VariantValue") -> bool:
"""The lhs satisfies the rhs if all possible concretizations of lhs are also
possible concretizations of rhs."""
if self.name != other.name:
@ -376,138 +393,90 @@ def satisfies(self, other: "VariantBase") -> bool:
if self.name == "patches":
return all(
isinstance(v, str)
and any(isinstance(w, str) and w.startswith(v) for w in self.value_as_tuple)
for v in other.value_as_tuple
and any(isinstance(w, str) and w.startswith(v) for w in self.values)
for v in other.values
)
return all(v in self for v in other.value_as_tuple)
return all(v in self for v in other.values)
if self.concrete:
# both concrete: they must be equal
return self.value_as_tuple == other.value_as_tuple
return self.values == other.values
return False
def intersects(self, other: "VariantBase") -> bool:
def intersects(self, other: "VariantValue") -> bool:
"""True iff there exists a concretization that satisfies both lhs and rhs."""
if self.name != other.name:
return False
if self.concrete:
if other.concrete:
return self.value_as_tuple == other.value_as_tuple
return all(v in self for v in other.value_as_tuple)
return self.values == other.values
return all(v in self for v in other.values)
if other.concrete:
return all(v in other for v in self.value_as_tuple)
return all(v in other for v in self.values)
# both abstract: the union is a valid concretization of both
return True
def constrain(self, other: "VariantBase") -> bool:
def constrain(self, other: "VariantValue") -> bool:
"""Constrain self with other if they intersect. Returns true iff self was changed."""
if not self.intersects(other):
raise UnsatisfiableVariantSpecError(self, other)
old_value = self.value
values = list(sorted({*self.value_as_tuple, *other.value_as_tuple}))
self._value_setter(",".join(str(v) for v in values))
changed = old_value != self.value
old_values = self.values
self.set(*self.values, *other.values)
changed = old_values != self.values
if self.propagate and not other.propagate:
self.propagate = False
changed = True
if not self.concrete and other.concrete:
self.concrete = True
changed = True
if self.type > other.type:
self.type = other.type
changed = True
return changed
def __contains__(self, item: Union[str, bool]) -> bool:
return item in self.value_as_tuple
def append(self, value: Union[str, bool]) -> None:
self.set(*self.values, value)
def __repr__(self) -> str:
return f"{type(self).__name__}({repr(self.name)}, {repr(self._original_value)})"
def __contains__(self, item: Union[str, bool]) -> bool:
return item in self.values
def __str__(self) -> str:
concrete = ":" if self.concrete else ""
# boolean variants are printed +foo or ~foo
if self.type == VariantType.BOOL:
sigil = "+" if self.value else "~"
if self.propagate:
sigil *= 2
return f"{sigil}{self.name}"
# concrete multi-valued foo:=bar,baz
concrete = ":" if self.type == VariantType.MULTI and self.concrete else ""
delim = "==" if self.propagate else "="
values_tuple = self.value_as_tuple
if values_tuple:
value_str = ",".join(str(v) for v in values_tuple)
else:
if not self.values:
value_str = "*"
elif self.name == "patches" and self.concrete:
value_str = ",".join(str(x)[:7] for x in self.values)
else:
value_str = ",".join(str(x) for x in self.values)
return f"{self.name}{concrete}{delim}{spack.spec_parser.quote_if_needed(value_str)}"
class MultiValuedVariant(VariantBase):
def __init__(self, name, value, propagate=False):
super().__init__(name, value, propagate)
self.concrete = True
def append(self, value: Union[str, bool]) -> None:
"""Add another value to this multi-valued variant."""
self._value = tuple(sorted((value,) + self.value_as_tuple))
self._original_value = ",".join(str(v) for v in self._value)
def __str__(self) -> str:
# Special-case patches to not print the full 64 character sha256
if self.name == "patches":
values_str = ",".join(str(x)[:7] for x in self.value_as_tuple)
else:
values_str = ",".join(str(x) for x in self.value_as_tuple)
delim = "==" if self.propagate else "="
return f"{self.name}:{delim}{spack.spec_parser.quote_if_needed(values_str)}"
def __repr__(self):
return (
f"VariantValue({self.type!r}, {self.name!r}, {self.values!r}, "
f"propagate={self.propagate!r}, concrete={self.concrete!r})"
)
class SingleValuedVariant(VariantBase):
def __init__(self, name, value, propagate=False):
super().__init__(name, value, propagate)
self.concrete = True
def _value_setter(self, value: ValueType) -> None:
# Treat the value as a multi-valued variant
super()._value_setter(value)
# Then check if there's only a single value
values = self.value_as_tuple
if len(values) != 1:
raise MultipleValuesInExclusiveVariantError(self)
self._value = values[0]
def __contains__(self, item: ValueType) -> bool:
return item == self.value
def yaml_entry(self) -> Tuple[str, SerializedValueType]:
assert isinstance(self.value, (bool, str))
return self.name, self.value
def __str__(self) -> str:
delim = "==" if self.propagate else "="
return f"{self.name}{delim}{spack.spec_parser.quote_if_needed(str(self.value))}"
def MultiValuedVariant(name: str, value: ValueType, propagate: bool = False) -> VariantValue:
return VariantValue(VariantType.MULTI, name, value, propagate=propagate, concrete=True)
class BoolValuedVariant(SingleValuedVariant):
def __init__(self, name, value, propagate=False):
super().__init__(name, value, propagate)
self.concrete = True
def SingleValuedVariant(
name: str, value: Union[bool, str], propagate: bool = False
) -> VariantValue:
return VariantValue(VariantType.SINGLE, name, (value,), propagate=propagate)
def _value_setter(self, value: ValueType) -> None:
# Check the string representation of the value and turn
# it to a boolean
if str(value).upper() == "TRUE":
self._original_value = value
self._value = True
elif str(value).upper() == "FALSE":
self._original_value = value
self._value = False
else:
raise ValueError(
f'cannot construct a BoolValuedVariant for "{self.name}" from '
"a value that does not represent a bool"
)
def __contains__(self, item: ValueType) -> bool:
return item is self.value
def __str__(self) -> str:
sigil = "+" if self.value else "~"
if self.propagate:
sigil *= 2
return f"{sigil}{self.name}"
def BoolValuedVariant(name: str, value: bool, propagate: bool = False) -> VariantValue:
return VariantValue(VariantType.BOOL, name, (value,), propagate=propagate)
# The class below inherit from Sequence to disguise as a tuple and comply
@ -714,7 +683,7 @@ def __lt__(self, other):
def prevalidate_variant_value(
pkg_cls: "Type[spack.package_base.PackageBase]",
variant: VariantBase,
variant: VariantValue,
spec: Optional["spack.spec.Spec"] = None,
strict: bool = False,
) -> List[Variant]:
@ -735,8 +704,8 @@ def prevalidate_variant_value(
list of variant definitions that will accept the given value. List will be empty
only if the variant is a reserved variant.
"""
# don't validate wildcards or variants with reserved names
if variant.value == ("*",) or variant.name in RESERVED_NAMES or variant.propagate:
# do not validate non-user variants or optional variants
if variant.name in RESERVED_NAMES or variant.propagate:
return []
# raise if there is no definition at all
@ -819,17 +788,13 @@ class MultipleValuesInExclusiveVariantError(spack.error.SpecError, ValueError):
only one.
"""
def __init__(self, variant: VariantBase, pkg_name: Optional[str] = None):
def __init__(self, variant: VariantValue, pkg_name: Optional[str] = None):
pkg_info = "" if pkg_name is None else f" in package '{pkg_name}'"
msg = f"multiple values are not allowed for variant '{variant.name}'{pkg_info}"
super().__init__(msg.format(variant, pkg_info))
class InvalidVariantValueCombinationError(spack.error.SpecError):
"""Raised when a variant has values '*' or 'none' with other values."""
class InvalidVariantValueError(spack.error.SpecError):
"""Raised when variants have invalid values."""