diff --git a/lib/spack/spack/solver/asp.py b/lib/spack/spack/solver/asp.py index 0c03e0b47e3..01d897c82e4 100644 --- a/lib/spack/spack/solver/asp.py +++ b/lib/spack/spack/solver/asp.py @@ -1755,15 +1755,17 @@ def define_variant( pkg_fact(fn.variant_condition(name, vid, cond_id)) # record type so we can construct the variant when we read it back in - self.gen.fact(fn.variant_type(vid, variant_def.variant_type.value)) + self.gen.fact(fn.variant_type(vid, variant_def.variant_type.string)) if variant_def.sticky: pkg_fact(fn.variant_sticky(vid)) # define defaults for this variant definition - defaults = variant_def.make_default().value if variant_def.multi else [variant_def.default] - for val in sorted(defaults): - pkg_fact(fn.variant_default_value_from_package_py(vid, val)) + if variant_def.multi: + for val in sorted(variant_def.make_default().values): + pkg_fact(fn.variant_default_value_from_package_py(vid, val)) + else: + pkg_fact(fn.variant_default_value_from_package_py(vid, variant_def.default)) # define possible values for this variant definition values = variant_def.values @@ -1791,7 +1793,9 @@ def define_variant( # make a spec indicating whether the variant has this conditional value variant_has_value = spack.spec.Spec() - variant_has_value.variants[name] = vt.VariantBase(name, value.value) + variant_has_value.variants[name] = vt.VariantValue( + vt.VariantType.MULTI, name, (value.value,) + ) if value.when: # the conditional value is always "possible", but it imposes its when condition as @@ -2373,7 +2377,7 @@ def preferred_variants(self, pkg_name): ) continue - for value in variant.value_as_tuple: + for value in variant.values: for variant_def in variant_defs: self.variant_values_from_specs.add((pkg_name, id(variant_def), value)) self.gen.fact( @@ -2491,7 +2495,7 @@ def _spec_clauses( if variant.value == ("*",): continue - for value in variant.value_as_tuple: + for value in variant.values: # ensure that the value *can* be valid for the spec if spec.name and not spec.concrete and not spack.repo.PATH.is_virtual(spec.name): variant_defs = vt.prevalidate_variant_value( @@ -3828,13 +3832,13 @@ def node_os(self, node, os): def node_target(self, node, target): self._arch(node).target = target - def variant_selected(self, node, name, value, variant_type, variant_id): + def variant_selected(self, node, name: str, value: str, variant_type: str, variant_id): spec = self._specs[node] variant = spec.variants.get(name) if not variant: - spec.variants[name] = vt.VariantType(variant_type).variant_class(name, value) + spec.variants[name] = vt.VariantValue.from_concretizer(name, value, variant_type) else: - assert variant_type == vt.VariantType.MULTI.value, ( + assert variant_type == "multi", ( f"Can't have multiple values for single-valued variant: " f"{node}, {name}, {value}, {variant_type}, {variant_id}" ) @@ -4213,10 +4217,10 @@ def _inject_patches_variant(root: spack.spec.Spec) -> None: continue patches = list(spec_to_patches[id(spec)]) - variant: vt.MultiValuedVariant = spec.variants.setdefault( + variant: vt.VariantValue = spec.variants.setdefault( "patches", vt.MultiValuedVariant("patches", ()) ) - variant.value = tuple(p.sha256 for p in patches) + variant.set(*(p.sha256 for p in patches)) # FIXME: Monkey patches variant to store patches order ordered_hashes = [(*p.ordering_key, p.sha256) for p in patches if p.ordering_key] ordered_hashes.sort() diff --git a/lib/spack/spack/spec.py b/lib/spack/spack/spec.py index d67e1517630..06c42a8e2fc 100644 --- a/lib/spack/spack/spec.py +++ b/lib/spack/spack/spec.py @@ -1698,7 +1698,9 @@ def _dependencies_dict(self, depflag: dt.DepFlag = dt.ALL): result[key] = list(group) return result - def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> None: + def _add_flag( + self, name: str, value: Union[str, bool], propagate: bool, concrete: bool + ) -> None: """Called by the parser to add a known flag""" if propagate and name in vt.RESERVED_NAMES: @@ -1708,6 +1710,7 @@ def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> N valid_flags = FlagMap.valid_compiler_flags() if name == "arch" or name == "architecture": + assert type(value) is str, "architecture have a string value" parts = tuple(value.split("-")) plat, os, tgt = parts if len(parts) == 3 else (None, None, value) self._set_architecture(platform=plat, os=os, target=tgt) @@ -1721,17 +1724,15 @@ def _add_flag(self, name: str, value: str, propagate: bool, concrete: bool) -> N self.namespace = value elif name in valid_flags: assert self.compiler_flags is not None + assert type(value) is str, f"{name} must have a string value" flags_and_propagation = spack.compilers.flags.tokenize_flags(value, propagate) flag_group = " ".join(x for (x, y) in flags_and_propagation) for flag, propagation in flags_and_propagation: self.compiler_flags.add_flag(name, flag, propagation, flag_group) else: - if str(value).upper() == "TRUE" or str(value).upper() == "FALSE": - self.variants[name] = vt.BoolValuedVariant(name, value, propagate) - elif concrete: - self.variants[name] = vt.MultiValuedVariant(name, value, propagate) - else: - self.variants[name] = vt.VariantBase(name, value, propagate) + self.variants[name] = vt.VariantValue.from_string_or_bool( + name, value, propagate=propagate, concrete=concrete + ) def _set_architecture(self, **kwargs): """Called by the parser to set the architecture.""" @@ -4481,7 +4482,7 @@ def __init__(self, spec: Spec): def __setitem__(self, name, vspec): # Raise a TypeError if vspec is not of the right type - if not isinstance(vspec, vt.VariantBase): + if not isinstance(vspec, vt.VariantValue): raise TypeError( "VariantMap accepts only values of variant types " f"[got {type(vspec).__name__} instead]" @@ -4621,7 +4622,7 @@ def __str__(self): bool_keys = [] kv_keys = [] for key in sorted_keys: - if isinstance(self[key].value, bool): + if self[key].type == vt.VariantType.BOOL: bool_keys.append(key) else: kv_keys.append(key) @@ -4654,7 +4655,8 @@ def substitute_abstract_variants(spec: Spec): unknown = [] for name, v in spec.variants.items(): if name == "dev_path": - spec.variants.substitute(vt.SingleValuedVariant(name, v._original_value)) + v.type = vt.VariantType.SINGLE + v.concrete = True continue elif name in vt.RESERVED_NAMES: continue @@ -4677,7 +4679,7 @@ def substitute_abstract_variants(spec: Spec): if rest: continue - new_variant = pkg_variant.make_variant(v._original_value) + new_variant = pkg_variant.make_variant(*v.values) pkg_variant.validate_or_raise(new_variant, spec.name) spec.variants.substitute(new_variant) @@ -4803,7 +4805,7 @@ def from_node_dict(cls, node): for val in values: spec.compiler_flags.add_flag(name, val, propagate) else: - spec.variants[name] = vt.MultiValuedVariant.from_node_dict( + spec.variants[name] = vt.VariantValue.from_node_dict( name, values, propagate=propagate, abstract=name in abstract_variants ) @@ -4829,7 +4831,7 @@ def from_node_dict(cls, node): patches = node["patches"] if len(patches) > 0: mvar = spec.variants.setdefault("patches", vt.MultiValuedVariant("patches", ())) - mvar.value = patches + mvar.set(*patches) # FIXME: Monkey patches mvar to store patches order mvar._patches_in_order_of_appearance = patches diff --git a/lib/spack/spack/spec_parser.py b/lib/spack/spack/spec_parser.py index 1d89cd93888..f6392c33808 100644 --- a/lib/spack/spack/spec_parser.py +++ b/lib/spack/spack/spec_parser.py @@ -62,7 +62,7 @@ import sys import traceback import warnings -from typing import Iterator, List, Optional, Tuple +from typing import Iterator, List, Optional, Tuple, Union from llnl.util.tty import color @@ -369,7 +369,7 @@ def raise_parsing_error(string: str, cause: Optional[Exception] = None): """Raise a spec parsing error with token context.""" raise SpecParsingError(string, self.ctx.current_token, self.literal_str) from cause - def add_flag(name: str, value: str, propagate: bool, concrete: bool): + def add_flag(name: str, value: Union[str, bool], propagate: bool, concrete: bool): """Wrapper around ``Spec._add_flag()`` that adds parser context to errors raised.""" try: initial_spec._add_flag(name, value, propagate, concrete) diff --git a/lib/spack/spack/test/spec_semantics.py b/lib/spack/spack/test/spec_semantics.py index 9232fcb0292..02673ffab68 100644 --- a/lib/spack/spack/test/spec_semantics.py +++ b/lib/spack/spack/test/spec_semantics.py @@ -973,13 +973,10 @@ def test_spec_formatting_bad_formats(self, default_mock_concretization, fmt_str) with pytest.raises(SpecFormatStringError): spec.format(fmt_str) - def test_combination_of_wildcard_or_none(self): - # Test that using 'none' and another value raises - with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot be combined"): - Spec("multivalue-variant foo=none,bar") - - # Test that using wildcard and another value raises - with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot be combined"): + def test_wildcard_is_invalid_variant_value(self): + """The spec string x=* is parsed as a multi-valued variant with values the empty set. + That excludes * as a literal variant value.""" + with pytest.raises(spack.spec_parser.SpecParsingError, match="cannot use reserved value"): Spec("multivalue-variant foo=*,bar") def test_errors_in_variant_directive(self): diff --git a/lib/spack/spack/test/variant.py b/lib/spack/spack/test/variant.py index edf9af681fb..d32bb35fb33 100644 --- a/lib/spack/spack/test/variant.py +++ b/lib/spack/spack/test/variant.py @@ -21,7 +21,7 @@ SingleValuedVariant, UnsatisfiableVariantSpecError, Variant, - VariantBase, + VariantValue, disjoint_sets, ) @@ -29,53 +29,37 @@ class TestMultiValuedVariant: def test_initialization(self): # Basic properties - a = MultiValuedVariant("foo", "bar,baz") - assert repr(a) == "MultiValuedVariant('foo', 'bar,baz')" + a = MultiValuedVariant("foo", ("bar", "baz")) assert str(a) == "foo:=bar,baz" + assert a.values == ("bar", "baz") assert a.value == ("bar", "baz") assert "bar" in a assert "baz" in a - assert eval(repr(a)) == a - - # Spaces are trimmed - b = MultiValuedVariant("foo", "bar, baz") - assert repr(b) == "MultiValuedVariant('foo', 'bar, baz')" - assert str(b) == "foo:=bar,baz" - assert b.value == ("bar", "baz") - assert "bar" in b - assert "baz" in b - assert a == b - assert hash(a) == hash(b) - assert eval(repr(b)) == a # Order is not important - c = MultiValuedVariant("foo", "baz, bar") - assert repr(c) == "MultiValuedVariant('foo', 'baz, bar')" + c = MultiValuedVariant("foo", ("baz", "bar")) assert str(c) == "foo:=bar,baz" - assert c.value == ("bar", "baz") + assert c.values == ("bar", "baz") assert "bar" in c assert "baz" in c assert a == c assert hash(a) == hash(c) - assert eval(repr(c)) == a # Check the copy d = a.copy() - assert repr(a) == repr(d) assert str(a) == str(d) - assert d.value == ("bar", "baz") + assert d.values == ("bar", "baz") assert "bar" in d assert "baz" in d assert a == d assert a is not d assert hash(a) == hash(d) - assert eval(repr(d)) == a def test_satisfies(self): - a = MultiValuedVariant("foo", "bar,baz") - b = MultiValuedVariant("foo", "bar") - c = MultiValuedVariant("fee", "bar,baz") - d = MultiValuedVariant("foo", "True") + a = MultiValuedVariant("foo", ("bar", "baz")) + b = MultiValuedVariant("foo", ("bar",)) + c = MultiValuedVariant("fee", ("bar", "baz")) + d = MultiValuedVariant("foo", (True,)) # concrete, different values do not satisfy each other assert not a.satisfies(b) and not b.satisfies(a) @@ -85,21 +69,19 @@ def test_satisfies(self): # eachother b_sv = SingleValuedVariant("foo", "bar") assert b.satisfies(b_sv) and b_sv.satisfies(b) - d_sv = SingleValuedVariant("foo", "True") + d_sv = SingleValuedVariant("foo", True) assert d.satisfies(d_sv) and d_sv.satisfies(d) - almost_d_bv = SingleValuedVariant("foo", "true") - assert not d.satisfies(almost_d_bv) + almost_d_bv = SingleValuedVariant("foo", True) + assert d.satisfies(almost_d_bv) - # BoolValuedVariant actually stores the value as a boolean, whereas with MV and SV the - # value is string "True". - d_bv = BoolValuedVariant("foo", "True") - assert not d.satisfies(d_bv) and not d_bv.satisfies(d) + d_bv = BoolValuedVariant("foo", True) + assert d.satisfies(d_bv) and d_bv.satisfies(d) def test_intersects(self): - a = MultiValuedVariant("foo", "bar,baz") - b = MultiValuedVariant("foo", "True") - c = MultiValuedVariant("fee", "bar,baz") - d = MultiValuedVariant("foo", "bar,barbaz") + a = MultiValuedVariant("foo", ("bar", "baz")) + b = MultiValuedVariant("foo", (True,)) + c = MultiValuedVariant("fee", ("bar", "baz")) + d = MultiValuedVariant("foo", ("bar", "barbaz")) # concrete, different values do not intersect. assert not a.intersects(b) and not b.intersects(a) @@ -110,47 +92,45 @@ def test_intersects(self): assert not c.intersects(d) and not d.intersects(c) # SV and MV intersect if they have the same concrete value. - b_sv = SingleValuedVariant("foo", "True") + b_sv = SingleValuedVariant("foo", True) assert b.intersects(b_sv) assert not c.intersects(b_sv) - # BoolValuedVariant stores a bool, which is not the same as the string "True" in MV. - b_bv = BoolValuedVariant("foo", "True") - assert not b.intersects(b_bv) + # BoolValuedVariant intersects if the value is the same + b_bv = BoolValuedVariant("foo", True) + assert b.intersects(b_bv) assert not c.intersects(b_bv) def test_constrain(self): # Concrete values cannot be constrained - a = MultiValuedVariant("foo", "bar,baz") - b = MultiValuedVariant("foo", "bar") + a = MultiValuedVariant("foo", ("bar", "baz")) + b = MultiValuedVariant("foo", ("bar",)) with pytest.raises(UnsatisfiableVariantSpecError): a.constrain(b) with pytest.raises(UnsatisfiableVariantSpecError): b.constrain(a) # Try to constrain on the same value - a = MultiValuedVariant("foo", "bar,baz") + a = MultiValuedVariant("foo", ("bar", "baz")) b = a.copy() assert not a.constrain(b) - assert a == b == MultiValuedVariant("foo", "bar,baz") + assert a == b == MultiValuedVariant("foo", ("bar", "baz")) # Try to constrain on a different name - a = MultiValuedVariant("foo", "bar,baz") - b = MultiValuedVariant("fee", "bar") + a = MultiValuedVariant("foo", ("bar", "baz")) + b = MultiValuedVariant("fee", ("bar",)) with pytest.raises(UnsatisfiableVariantSpecError): a.constrain(b) def test_yaml_entry(self): - a = MultiValuedVariant("foo", "bar,baz,barbaz") - b = MultiValuedVariant("foo", "bar, baz, barbaz") - expected = ("foo", sorted(["bar", "baz", "barbaz"])) + a = MultiValuedVariant("foo", ("bar", "baz", "barbaz")) + expected = ("foo", sorted(("bar", "baz", "barbaz"))) assert a.yaml_entry() == expected - assert b.yaml_entry() == expected - a = MultiValuedVariant("foo", "bar") + a = MultiValuedVariant("foo", ("bar",)) expected = ("foo", sorted(["bar"])) assert a.yaml_entry() == expected @@ -160,26 +140,20 @@ class TestSingleValuedVariant: def test_initialization(self): # Basic properties a = SingleValuedVariant("foo", "bar") - assert repr(a) == "SingleValuedVariant('foo', 'bar')" assert str(a) == "foo=bar" + assert a.values == ("bar",) assert a.value == "bar" assert "bar" in a - assert eval(repr(a)) == a - - # Raise if multiple values are passed - with pytest.raises(ValueError): - SingleValuedVariant("foo", "bar, baz") # Check the copy b = a.copy() - assert repr(a) == repr(b) assert str(a) == str(b) + assert b.values == ("bar",) assert b.value == "bar" assert "bar" in b assert a == b assert a is not b assert hash(a) == hash(b) - assert eval(repr(b)) == a def test_satisfies(self): a = SingleValuedVariant("foo", "bar") @@ -247,54 +221,37 @@ def test_yaml_entry(self): class TestBoolValuedVariant: def test_initialization(self): # Basic properties - True value - for v in (True, "True", "TRUE", "TrUe"): - a = BoolValuedVariant("foo", v) - assert repr(a) == "BoolValuedVariant('foo', {0})".format(repr(v)) - assert str(a) == "+foo" - assert a.value is True - assert True in a - assert eval(repr(a)) == a + a = BoolValuedVariant("foo", True) + assert str(a) == "+foo" + assert a.value is True + assert a.values == (True,) + assert True in a # Copy - True value b = a.copy() - assert repr(a) == repr(b) assert str(a) == str(b) assert b.value is True + assert b.values == (True,) assert True in b assert a == b assert a is not b assert hash(a) == hash(b) - assert eval(repr(b)) == a - # Basic properties - False value - for v in (False, "False", "FALSE", "FaLsE"): - a = BoolValuedVariant("foo", v) - assert repr(a) == "BoolValuedVariant('foo', {0})".format(repr(v)) - assert str(a) == "~foo" - assert a.value is False - assert False in a - assert eval(repr(a)) == a - - # Copy - True value + # Copy - False value + a = BoolValuedVariant("foo", False) b = a.copy() - assert repr(a) == repr(b) assert str(a) == str(b) assert b.value is False + assert b.values == (False,) assert False in b assert a == b assert a is not b - assert eval(repr(b)) == a - - # Invalid values - for v in ("bar", "bar,baz"): - with pytest.raises(ValueError): - BoolValuedVariant("foo", v) def test_satisfies(self): a = BoolValuedVariant("foo", True) b = BoolValuedVariant("foo", False) c = BoolValuedVariant("fee", False) - d = BoolValuedVariant("foo", "True") + d = BoolValuedVariant("foo", True) # concrete, different values do not satisfy each other assert not a.satisfies(b) and not b.satisfies(a) @@ -325,7 +282,7 @@ def test_intersects(self): a = BoolValuedVariant("foo", True) b = BoolValuedVariant("fee", True) c = BoolValuedVariant("foo", False) - d = BoolValuedVariant("foo", "True") + d = BoolValuedVariant("foo", True) # concrete, different values do not intersect each other assert not a.intersects(b) and not b.intersects(a) @@ -347,7 +304,7 @@ def test_intersects(self): def test_constrain(self): # Try to constrain on a value equal to self - a = BoolValuedVariant("foo", "True") + a = BoolValuedVariant("foo", True) b = BoolValuedVariant("foo", True) assert not a.constrain(b) @@ -375,24 +332,24 @@ def test_constrain(self): assert a == BoolValuedVariant("foo", True) def test_yaml_entry(self): - a = BoolValuedVariant("foo", "True") + a = BoolValuedVariant("foo", True) expected = ("foo", True) assert a.yaml_entry() == expected - a = BoolValuedVariant("foo", "False") + a = BoolValuedVariant("foo", False) expected = ("foo", False) assert a.yaml_entry() == expected def test_from_node_dict(): - a = MultiValuedVariant.from_node_dict("foo", ["bar"]) - assert type(a) is MultiValuedVariant + a = VariantValue.from_node_dict("foo", ["bar"]) + assert a.type == spack.variant.VariantType.MULTI - a = MultiValuedVariant.from_node_dict("foo", "bar") - assert type(a) is SingleValuedVariant + a = VariantValue.from_node_dict("foo", "bar") + assert a.type == spack.variant.VariantType.SINGLE - a = MultiValuedVariant.from_node_dict("foo", "true") - assert type(a) is BoolValuedVariant + a = VariantValue.from_node_dict("foo", "true") + assert a.type == spack.variant.VariantType.BOOL class TestVariant: @@ -406,7 +363,7 @@ def test_validation(self): # Multiple values are not allowed with pytest.raises(MultipleValuesInExclusiveVariantError): - vspec.value = "bar,baz" + vspec.set("bar", "baz") # Inconsistent vspec vspec.name = "FOO" @@ -415,10 +372,10 @@ def test_validation(self): # Valid multi-value vspec a.multi = True - vspec = a.make_variant("bar,baz") + vspec = a.make_variant("bar", "baz") a.validate_or_raise(vspec, "test-package") # Add an invalid value - vspec.value = "bar,baz,barbaz" + vspec.set("bar", "baz", "barbaz") with pytest.raises(InvalidVariantValueError): a.validate_or_raise(vspec, "test-package") @@ -429,12 +386,12 @@ def validator(x): except ValueError: return False - a = Variant("foo", default=1024, description="", values=validator, multi=False) + a = Variant("foo", default="1024", description="", values=validator, multi=False) vspec = a.make_default() a.validate_or_raise(vspec, "test-package") - vspec.value = 2056 + vspec.set("2056") a.validate_or_raise(vspec, "test-package") - vspec.value = "foo" + vspec.set("foo") with pytest.raises(InvalidVariantValueError): a.validate_or_raise(vspec, "test-package") @@ -464,9 +421,9 @@ def test_invalid_values(self) -> None: a["foo"] = 2 # Duplicate variant - a["foo"] = MultiValuedVariant("foo", "bar,baz") + a["foo"] = MultiValuedVariant("foo", ("bar", "baz")) with pytest.raises(DuplicateVariantError): - a["foo"] = MultiValuedVariant("foo", "bar") + a["foo"] = MultiValuedVariant("foo", ("bar",)) with pytest.raises(DuplicateVariantError): a["foo"] = SingleValuedVariant("foo", "bar") @@ -476,7 +433,7 @@ def test_invalid_values(self) -> None: # Non matching names between key and vspec.name with pytest.raises(KeyError): - a["bar"] = MultiValuedVariant("foo", "bar") + a["bar"] = MultiValuedVariant("foo", ("bar",)) def test_set_item(self) -> None: # Check that all the three types of variants are accepted @@ -484,7 +441,7 @@ def test_set_item(self) -> None: a["foo"] = BoolValuedVariant("foo", True) a["bar"] = SingleValuedVariant("bar", "baz") - a["foobar"] = MultiValuedVariant("foobar", "a, b, c, d, e") + a["foobar"] = MultiValuedVariant("foobar", ("a", "b", "c", "d", "e")) def test_substitute(self) -> None: # Check substitution of a key that exists @@ -500,13 +457,13 @@ def test_substitute(self) -> None: def test_satisfies_and_constrain(self) -> None: # foo=bar foobar=fee feebar=foo a = VariantMap(Spec()) - a["foo"] = MultiValuedVariant("foo", "bar") + a["foo"] = MultiValuedVariant("foo", ("bar",)) a["foobar"] = SingleValuedVariant("foobar", "fee") a["feebar"] = SingleValuedVariant("feebar", "foo") # foo=bar,baz foobar=fee shared=True b = VariantMap(Spec()) - b["foo"] = MultiValuedVariant("foo", "bar, baz") + b["foo"] = MultiValuedVariant("foo", ("bar", "baz")) b["foobar"] = SingleValuedVariant("foobar", "fee") b["shared"] = BoolValuedVariant("shared", True) @@ -516,7 +473,7 @@ def test_satisfies_and_constrain(self) -> None: # foo=bar,baz foobar=fee feebar=foo shared=True c = VariantMap(Spec()) - c["foo"] = MultiValuedVariant("foo", "bar, baz") + c["foo"] = MultiValuedVariant("foo", ("bar", "baz")) c["foobar"] = SingleValuedVariant("foobar", "fee") c["feebar"] = SingleValuedVariant("feebar", "foo") c["shared"] = BoolValuedVariant("shared", True) @@ -529,14 +486,14 @@ def test_copy(self) -> None: a = VariantMap(Spec()) a["foo"] = BoolValuedVariant("foo", True) a["bar"] = SingleValuedVariant("bar", "baz") - a["foobar"] = MultiValuedVariant("foobar", "a, b, c, d, e") + a["foobar"] = MultiValuedVariant("foobar", ("a", "b", "c", "d", "e")) c = a.copy() assert a == c def test_str(self) -> None: c = VariantMap(Spec()) - c["foo"] = MultiValuedVariant("foo", "bar, baz") + c["foo"] = MultiValuedVariant("foo", ("bar", "baz")) c["foobar"] = SingleValuedVariant("foobar", "fee") c["feebar"] = SingleValuedVariant("feebar", "foo") c["shared"] = BoolValuedVariant("shared", True) @@ -635,10 +592,10 @@ def test_wild_card_valued_variants_equivalent_to_str(): several_arbitrary_values = ("doe", "re", "mi") # "*" case - wild_output = wild_var.make_variant(several_arbitrary_values) + wild_output = wild_var.make_variant(*several_arbitrary_values) wild_var.validate_or_raise(wild_output, "test-package") # str case - str_output = str_var.make_variant(several_arbitrary_values) + str_output = str_var.make_variant(*several_arbitrary_values) str_var.validate_or_raise(str_output, "test-package") # equivalence each instance already validated assert str_output.value == wild_output.value @@ -760,21 +717,21 @@ def test_concretize_variant_default_with_multiple_defs( "spec,variant_name,narrowed_type", [ # dev_path is a special case - ("foo dev_path=/path/to/source", "dev_path", SingleValuedVariant), + ("foo dev_path=/path/to/source", "dev_path", spack.variant.VariantType.SINGLE), # reserved name: won't be touched - ("foo patches=2349dc44", "patches", VariantBase), + ("foo patches=2349dc44", "patches", spack.variant.VariantType.MULTI), # simple case -- one definition applies - ("variant-values@1.0 v=foo", "v", SingleValuedVariant), + ("variant-values@1.0 v=foo", "v", spack.variant.VariantType.SINGLE), # simple, but with bool valued variant - ("pkg-a bvv=true", "bvv", BoolValuedVariant), + ("pkg-a bvv=true", "bvv", spack.variant.VariantType.BOOL), # takes the second definition, which overrides the single-valued one - ("variant-values@2.0 v=bar", "v", MultiValuedVariant), + ("variant-values@2.0 v=bar", "v", spack.variant.VariantType.MULTI), ], ) def test_substitute_abstract_variants_narrowing(mock_packages, spec, variant_name, narrowed_type): spec = Spec(spec) spack.spec.substitute_abstract_variants(spec) - assert type(spec.variants[variant_name]) is narrowed_type + assert spec.variants[variant_name].type == narrowed_type def test_substitute_abstract_variants_failure(mock_packages): @@ -920,3 +877,12 @@ def test_patches_variant(): assert not Spec("patches:=abcdef").satisfies("patches:=ab") assert not Spec("patches:=abcdef,xyz").satisfies("patches:=abc,xyz") assert not Spec("patches:=abcdef").satisfies("patches:=abcdefghi") + + +def test_constrain_narrowing(): + s = Spec("foo=*") + assert s.variants["foo"].type == spack.variant.VariantType.MULTI + assert not s.variants["foo"].concrete + s.constrain("+foo") + assert s.variants["foo"].type == spack.variant.VariantType.BOOL + assert s.variants["foo"].concrete diff --git a/lib/spack/spack/variant.py b/lib/spack/spack/variant.py index e9700ed9466..cd4de737779 100644 --- a/lib/spack/spack/variant.py +++ b/lib/spack/spack/variant.py @@ -10,7 +10,6 @@ import functools import inspect import itertools -import re from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, Type, Union import llnl.util.lang as lang @@ -33,24 +32,22 @@ "target", } -special_variant_values = [None, "none", "*"] - -class VariantType(enum.Enum): +class VariantType(enum.IntEnum): """Enum representing the three concrete variant types.""" - MULTI = "multi" - BOOL = "bool" - SINGLE = "single" + BOOL = 1 + SINGLE = 2 + MULTI = 3 @property - def variant_class(self) -> Type: - if self is self.MULTI: - return MultiValuedVariant - elif self is self.BOOL: - return BoolValuedVariant - else: - return SingleValuedVariant + def string(self) -> str: + """Convert the variant type to a string.""" + if self == VariantType.BOOL: + return "bool" + elif self == VariantType.SINGLE: + return "single" + return "multi" class Variant: @@ -134,7 +131,7 @@ def isa_type(v): self.sticky = sticky self.precedence = precedence - def validate_or_raise(self, vspec: "VariantBase", pkg_name: str): + def validate_or_raise(self, vspec: "VariantValue", pkg_name: str): """Validate a variant spec against this package variant. Raises an exception if any error is found. @@ -156,7 +153,7 @@ def validate_or_raise(self, vspec: "VariantBase", pkg_name: str): raise InconsistentValidationError(vspec, self) # If the value is exclusive there must be at most one - value = vspec.value_as_tuple + value = vspec.values if not self.multi and len(value) != 1: raise MultipleValuesInExclusiveVariantError(vspec, pkg_name) @@ -191,27 +188,15 @@ def allowed_values(self): v = docstring if docstring else "" return v - def make_default(self): - """Factory that creates a variant holding the default value. + def make_default(self) -> "VariantValue": + """Factory that creates a variant holding the default value(s).""" + variant = VariantValue.from_string_or_bool(self.name, self.default) + variant.type = self.variant_type + return variant - Returns: - MultiValuedVariant or SingleValuedVariant or BoolValuedVariant: - instance of the proper variant - """ - return self.make_variant(self.default) - - def make_variant(self, value: Union[str, bool]) -> "VariantBase": - """Factory that creates a variant holding the value passed as - a parameter. - - Args: - value: value that will be hold by the variant - - Returns: - MultiValuedVariant or SingleValuedVariant or BoolValuedVariant: - instance of the proper variant - """ - return self.variant_type.variant_class(self.name, value) + def make_variant(self, *value: Union[str, bool]) -> "VariantValue": + """Factory that creates a variant holding the value(s) passed.""" + return VariantValue(self.variant_type, self.name, value) @property def variant_type(self) -> VariantType: @@ -254,116 +239,148 @@ def _flatten(values) -> Collection: #: Type for value of a variant -ValueType = Union[str, bool, Tuple[Union[str, bool], ...]] +ValueType = Tuple[Union[bool, str], ...] #: Type of variant value when output for JSON, YAML, etc. SerializedValueType = Union[str, bool, List[Union[str, bool]]] @lang.lazy_lexicographic_ordering -class VariantBase: - """A BaseVariant corresponds to a spec string of the form ``foo=bar`` or ``foo=bar,baz``. - It is a constraint on the spec and abstract in the sense that it must have **at least** these - values -- concretization may add more values.""" +class VariantValue: + """A VariantValue is a key-value pair that represents a variant. It can have zero or more + values. Values have set semantics, so they are unordered and unique. The variant type can + be narrowed from multi to single to boolean, this limits the number of values that can be + stored in the variant. Multi-valued variants can either be concrete or abstract: abstract + means that the variant takes at least the values specified, but may take more when concretized. + Concrete means that the variant takes exactly the values specified. Lastly, a variant can be + marked as propagating, which means that it should be propagated to dependencies.""" name: str propagate: bool - _value: ValueType - _original_value: Any + concrete: bool + type: VariantType + _values: ValueType - def __init__(self, name: str, value: ValueType, propagate: bool = False) -> None: + slots = ("name", "propagate", "concrete", "type", "_values") + + def __init__( + self, + type: VariantType, + name: str, + value: ValueType, + *, + propagate: bool = False, + concrete: bool = False, + ) -> None: self.name = name + self.type = type self.propagate = propagate - self.concrete = False + # only multi-valued variants can be abstract + self.concrete = concrete or type in (VariantType.BOOL, VariantType.SINGLE) # Invokes property setter - self.value = value + self.set(*value) @staticmethod def from_node_dict( name: str, value: Union[str, List[str]], *, propagate: bool = False, abstract: bool = False - ) -> "VariantBase": + ) -> "VariantValue": """Reconstruct a variant from a node dict.""" if isinstance(value, list): - constructor = VariantBase if abstract else MultiValuedVariant - mvar = constructor(name, (), propagate=propagate) - mvar._value = tuple(value) - mvar._original_value = mvar._value - return mvar + return VariantValue( + VariantType.MULTI, name, tuple(value), propagate=propagate, concrete=not abstract + ) + # todo: is this necessary? not literal true / false in json/yaml? elif str(value).upper() == "TRUE" or str(value).upper() == "FALSE": - return BoolValuedVariant(name, value, propagate=propagate) + return VariantValue( + VariantType.BOOL, name, (str(value).upper() == "TRUE",), propagate=propagate + ) - return SingleValuedVariant(name, value, propagate=propagate) + return VariantValue(VariantType.SINGLE, name, (value,), propagate=propagate) + + @staticmethod + def from_string_or_bool( + name: str, value: Union[str, bool], *, propagate: bool = False, concrete: bool = False + ) -> "VariantValue": + if value is True or value is False: + return VariantValue(VariantType.BOOL, name, (value,), propagate=propagate) + + elif value.upper() in ("TRUE", "FALSE"): + return VariantValue( + VariantType.BOOL, name, (value.upper() == "TRUE",), propagate=propagate + ) + + elif value == "*": + return VariantValue(VariantType.MULTI, name, (), propagate=propagate) + + return VariantValue( + VariantType.MULTI, + name, + tuple(value.split(",")), + propagate=propagate, + concrete=concrete, + ) + + @staticmethod + def from_concretizer(name: str, value: str, type: str) -> "VariantValue": + """Reconstruct a variant from concretizer output.""" + if type == "bool": + return VariantValue(VariantType.BOOL, name, (value == "True",)) + elif type == "multi": + return VariantValue(VariantType.MULTI, name, (value,), concrete=True) + else: + return VariantValue(VariantType.SINGLE, name, (value,)) def yaml_entry(self) -> Tuple[str, SerializedValueType]: - """Returns a key, value tuple suitable to be an entry in a yaml dict. + """Returns a (key, value) tuple suitable to be an entry in a yaml dict. Returns: tuple: (name, value_representation) """ - return self.name, list(self.value_as_tuple) + if self.type == VariantType.MULTI: + return self.name, list(self.values) + return self.name, self.values[0] @property - def value_as_tuple(self) -> Tuple[Union[bool, str], ...]: - """Getter for self.value that always returns a Tuple (even for single valued variants). - - This makes it easy to iterate over possible values. - """ - if isinstance(self._value, (bool, str)): - return (self._value,) - return self._value + def values(self) -> ValueType: + return self._values @property - def value(self) -> ValueType: - """Returns a tuple of strings containing the values stored in - the variant. + def value(self) -> Union[ValueType, bool, str]: + return self._values[0] if self.type != VariantType.MULTI else self._values - Returns: - tuple: values stored in the variant - """ - return self._value + def set(self, *value: Union[bool, str]) -> None: + """Set the value(s) of the variant.""" + if len(value) > 1: + value = tuple(sorted(set(value))) - @value.setter - def value(self, value: ValueType) -> None: - self._value_setter(value) + if self.type != VariantType.MULTI: + if len(value) != 1: + raise MultipleValuesInExclusiveVariantError(self) + unwrapped = value[0] + if self.type == VariantType.BOOL and unwrapped not in (True, False): + raise ValueError( + f"cannot set a boolean variant to a value that is not a boolean: {unwrapped}" + ) - def _value_setter(self, value: ValueType) -> None: - # Store the original value - self._original_value = value + if "*" in value: + raise InvalidVariantValueError("cannot use reserved value '*'") - if value == "*": - self._value = () - return - - if not isinstance(value, (tuple, list)): - # Store a tuple of CSV string representations - # Tuple is necessary here instead of list because the - # values need to be hashed - value = tuple(re.split(r"\s*,\s*", str(value))) - - for val in special_variant_values: - if val in value and len(value) > 1: - msg = "'%s' cannot be combined with other variant" % val - msg += " values." - raise InvalidVariantValueCombinationError(msg) - - # With multi-value variants it is necessary - # to remove duplicates and give an order - # to a set - self._value = tuple(sorted(set(value))) + self._values = value def _cmp_iter(self) -> Iterable: yield self.name yield self.propagate - yield from (str(v) for v in self.value_as_tuple) + yield self.concrete + yield from (str(v) for v in self.values) - def copy(self) -> "VariantBase": - variant = type(self)(self.name, self._original_value, self.propagate) - variant.concrete = self.concrete - return variant + def copy(self) -> "VariantValue": + return VariantValue( + self.type, self.name, self.values, propagate=self.propagate, concrete=self.concrete + ) - def satisfies(self, other: "VariantBase") -> bool: + def satisfies(self, other: "VariantValue") -> bool: """The lhs satisfies the rhs if all possible concretizations of lhs are also possible concretizations of rhs.""" if self.name != other.name: @@ -376,138 +393,90 @@ def satisfies(self, other: "VariantBase") -> bool: if self.name == "patches": return all( isinstance(v, str) - and any(isinstance(w, str) and w.startswith(v) for w in self.value_as_tuple) - for v in other.value_as_tuple + and any(isinstance(w, str) and w.startswith(v) for w in self.values) + for v in other.values ) - return all(v in self for v in other.value_as_tuple) + return all(v in self for v in other.values) if self.concrete: # both concrete: they must be equal - return self.value_as_tuple == other.value_as_tuple + return self.values == other.values return False - def intersects(self, other: "VariantBase") -> bool: + def intersects(self, other: "VariantValue") -> bool: """True iff there exists a concretization that satisfies both lhs and rhs.""" if self.name != other.name: return False if self.concrete: if other.concrete: - return self.value_as_tuple == other.value_as_tuple - return all(v in self for v in other.value_as_tuple) + return self.values == other.values + return all(v in self for v in other.values) if other.concrete: - return all(v in other for v in self.value_as_tuple) + return all(v in other for v in self.values) # both abstract: the union is a valid concretization of both return True - def constrain(self, other: "VariantBase") -> bool: + def constrain(self, other: "VariantValue") -> bool: """Constrain self with other if they intersect. Returns true iff self was changed.""" if not self.intersects(other): raise UnsatisfiableVariantSpecError(self, other) - old_value = self.value - values = list(sorted({*self.value_as_tuple, *other.value_as_tuple})) - self._value_setter(",".join(str(v) for v in values)) - changed = old_value != self.value + old_values = self.values + self.set(*self.values, *other.values) + changed = old_values != self.values if self.propagate and not other.propagate: self.propagate = False changed = True if not self.concrete and other.concrete: self.concrete = True changed = True + if self.type > other.type: + self.type = other.type + changed = True return changed - def __contains__(self, item: Union[str, bool]) -> bool: - return item in self.value_as_tuple + def append(self, value: Union[str, bool]) -> None: + self.set(*self.values, value) - def __repr__(self) -> str: - return f"{type(self).__name__}({repr(self.name)}, {repr(self._original_value)})" + def __contains__(self, item: Union[str, bool]) -> bool: + return item in self.values def __str__(self) -> str: - concrete = ":" if self.concrete else "" + # boolean variants are printed +foo or ~foo + if self.type == VariantType.BOOL: + sigil = "+" if self.value else "~" + if self.propagate: + sigil *= 2 + return f"{sigil}{self.name}" + + # concrete multi-valued foo:=bar,baz + concrete = ":" if self.type == VariantType.MULTI and self.concrete else "" delim = "==" if self.propagate else "=" - values_tuple = self.value_as_tuple - if values_tuple: - value_str = ",".join(str(v) for v in values_tuple) - else: + if not self.values: value_str = "*" + elif self.name == "patches" and self.concrete: + value_str = ",".join(str(x)[:7] for x in self.values) + else: + value_str = ",".join(str(x) for x in self.values) return f"{self.name}{concrete}{delim}{spack.spec_parser.quote_if_needed(value_str)}" - -class MultiValuedVariant(VariantBase): - def __init__(self, name, value, propagate=False): - super().__init__(name, value, propagate) - self.concrete = True - - def append(self, value: Union[str, bool]) -> None: - """Add another value to this multi-valued variant.""" - self._value = tuple(sorted((value,) + self.value_as_tuple)) - self._original_value = ",".join(str(v) for v in self._value) - - def __str__(self) -> str: - # Special-case patches to not print the full 64 character sha256 - if self.name == "patches": - values_str = ",".join(str(x)[:7] for x in self.value_as_tuple) - else: - values_str = ",".join(str(x) for x in self.value_as_tuple) - - delim = "==" if self.propagate else "=" - return f"{self.name}:{delim}{spack.spec_parser.quote_if_needed(values_str)}" + def __repr__(self): + return ( + f"VariantValue({self.type!r}, {self.name!r}, {self.values!r}, " + f"propagate={self.propagate!r}, concrete={self.concrete!r})" + ) -class SingleValuedVariant(VariantBase): - def __init__(self, name, value, propagate=False): - super().__init__(name, value, propagate) - self.concrete = True - - def _value_setter(self, value: ValueType) -> None: - # Treat the value as a multi-valued variant - super()._value_setter(value) - - # Then check if there's only a single value - values = self.value_as_tuple - if len(values) != 1: - raise MultipleValuesInExclusiveVariantError(self) - - self._value = values[0] - - def __contains__(self, item: ValueType) -> bool: - return item == self.value - - def yaml_entry(self) -> Tuple[str, SerializedValueType]: - assert isinstance(self.value, (bool, str)) - return self.name, self.value - - def __str__(self) -> str: - delim = "==" if self.propagate else "=" - return f"{self.name}{delim}{spack.spec_parser.quote_if_needed(str(self.value))}" +def MultiValuedVariant(name: str, value: ValueType, propagate: bool = False) -> VariantValue: + return VariantValue(VariantType.MULTI, name, value, propagate=propagate, concrete=True) -class BoolValuedVariant(SingleValuedVariant): - def __init__(self, name, value, propagate=False): - super().__init__(name, value, propagate) - self.concrete = True +def SingleValuedVariant( + name: str, value: Union[bool, str], propagate: bool = False +) -> VariantValue: + return VariantValue(VariantType.SINGLE, name, (value,), propagate=propagate) - def _value_setter(self, value: ValueType) -> None: - # Check the string representation of the value and turn - # it to a boolean - if str(value).upper() == "TRUE": - self._original_value = value - self._value = True - elif str(value).upper() == "FALSE": - self._original_value = value - self._value = False - else: - raise ValueError( - f'cannot construct a BoolValuedVariant for "{self.name}" from ' - "a value that does not represent a bool" - ) - def __contains__(self, item: ValueType) -> bool: - return item is self.value - - def __str__(self) -> str: - sigil = "+" if self.value else "~" - if self.propagate: - sigil *= 2 - return f"{sigil}{self.name}" +def BoolValuedVariant(name: str, value: bool, propagate: bool = False) -> VariantValue: + return VariantValue(VariantType.BOOL, name, (value,), propagate=propagate) # The class below inherit from Sequence to disguise as a tuple and comply @@ -714,7 +683,7 @@ def __lt__(self, other): def prevalidate_variant_value( pkg_cls: "Type[spack.package_base.PackageBase]", - variant: VariantBase, + variant: VariantValue, spec: Optional["spack.spec.Spec"] = None, strict: bool = False, ) -> List[Variant]: @@ -735,8 +704,8 @@ def prevalidate_variant_value( list of variant definitions that will accept the given value. List will be empty only if the variant is a reserved variant. """ - # don't validate wildcards or variants with reserved names - if variant.value == ("*",) or variant.name in RESERVED_NAMES or variant.propagate: + # do not validate non-user variants or optional variants + if variant.name in RESERVED_NAMES or variant.propagate: return [] # raise if there is no definition at all @@ -819,17 +788,13 @@ class MultipleValuesInExclusiveVariantError(spack.error.SpecError, ValueError): only one. """ - def __init__(self, variant: VariantBase, pkg_name: Optional[str] = None): + def __init__(self, variant: VariantValue, pkg_name: Optional[str] = None): pkg_info = "" if pkg_name is None else f" in package '{pkg_name}'" msg = f"multiple values are not allowed for variant '{variant.name}'{pkg_info}" super().__init__(msg.format(variant, pkg_info)) -class InvalidVariantValueCombinationError(spack.error.SpecError): - """Raised when a variant has values '*' or 'none' with other values.""" - - class InvalidVariantValueError(spack.error.SpecError): """Raised when variants have invalid values."""