Refactor to improve spec format speed (#43712)

When looking at where we spend our time in solver setup, I noticed a fair bit of time is spent
in `Spec.format()`, and `Spec.format()` is a pretty old, slow, convoluted method.

This PR does a number of things:
- [x] Consolidate most of what was being done manually with a character loop and several
      regexes into a single regex.
- [x] Precompile regexes where we keep them 
- [x] Remove the `transform=` argument to `Spec.format()` which was only used in one 
      place in the code (modules) to uppercase env var names, but added a lot of complexity
- [x] Avoid escaping and colorizing specs unless necessary
- [x] Refactor a lot of the colorization logic to avoid unnecessary object construction
- [x] Add type hints and remove some spots in the code where we were using nonexistent
      arguments to `format()`.
- [x] Add trivial cases to `__str__` in `VariantMap` and `VersionList` to avoid sorting
- [x] Avoid calling `isinstance()` in the main loop of `Spec.format()`
- [x] Don't bother constructing a `string` representation for the result of `_prev_version`
      as it is only used for comparisons.

In my timings (on all the specs formatted in a solve of `hdf5`), this is over 2.67x faster than the 
original `format()`, and it seems to reduce setup time by around a second (for `hdf5`).
This commit is contained in:
Todd Gamblin 2024-04-23 10:52:15 -07:00 committed by GitHub
parent 978c20f35a
commit aa0825d642
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 226 additions and 270 deletions

View File

@ -12,7 +12,7 @@
import traceback import traceback
from datetime import datetime from datetime import datetime
from sys import platform as _platform from sys import platform as _platform
from typing import NoReturn from typing import Any, NoReturn
if _platform != "win32": if _platform != "win32":
import fcntl import fcntl
@ -158,21 +158,22 @@ def get_timestamp(force=False):
return "" return ""
def msg(message, *args, **kwargs): def msg(message: Any, *args: Any, newline: bool = True) -> None:
if not msg_enabled(): if not msg_enabled():
return return
if isinstance(message, Exception): if isinstance(message, Exception):
message = "%s: %s" % (message.__class__.__name__, str(message)) message = f"{message.__class__.__name__}: {message}"
else:
message = str(message)
newline = kwargs.get("newline", True)
st_text = "" st_text = ""
if _stacktrace: if _stacktrace:
st_text = process_stacktrace(2) st_text = process_stacktrace(2)
if newline:
cprint("@*b{%s==>} %s%s" % (st_text, get_timestamp(), cescape(_output_filter(message)))) nl = "\n" if newline else ""
else: cwrite(f"@*b{{{st_text}==>}} {get_timestamp()}{cescape(_output_filter(message))}{nl}")
cwrite("@*b{%s==>} %s%s" % (st_text, get_timestamp(), cescape(_output_filter(message))))
for arg in args: for arg in args:
print(indent + _output_filter(str(arg))) print(indent + _output_filter(str(arg)))

View File

@ -62,6 +62,7 @@
import re import re
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional
class ColorParseError(Exception): class ColorParseError(Exception):
@ -95,7 +96,7 @@ def __init__(self, message):
} # white } # white
# Regex to be used for color formatting # Regex to be used for color formatting
color_re = r"@(?:@|\.|([*_])?([a-zA-Z])?(?:{((?:[^}]|}})*)})?)" COLOR_RE = re.compile(r"@(?:(@)|(\.)|([*_])?([a-zA-Z])?(?:{((?:[^}]|}})*)})?)")
# Mapping from color arguments to values for tty.set_color # Mapping from color arguments to values for tty.set_color
color_when_values = {"always": True, "auto": None, "never": False} color_when_values = {"always": True, "auto": None, "never": False}
@ -203,77 +204,64 @@ def color_when(value):
set_color_when(old_value) set_color_when(old_value)
class match_to_ansi: def _escape(s: str, color: bool, enclose: bool, zsh: bool) -> str:
def __init__(self, color=True, enclose=False, zsh=False): """Returns a TTY escape sequence for a color"""
self.color = _color_when_value(color) if color:
self.enclose = enclose if zsh:
self.zsh = zsh result = rf"\e[0;{s}m"
def escape(self, s):
"""Returns a TTY escape sequence for a color"""
if self.color:
if self.zsh:
result = rf"\e[0;{s}m"
else:
result = f"\033[{s}m"
if self.enclose:
result = rf"\[{result}\]"
return result
else: else:
return "" result = f"\033[{s}m"
def __call__(self, match): if enclose:
"""Convert a match object generated by ``color_re`` into an ansi result = rf"\[{result}\]"
color code. This can be used as a handler in ``re.sub``.
"""
style, color, text = match.groups()
m = match.group(0)
if m == "@@": return result
return "@" else:
elif m == "@.": return ""
return self.escape(0)
elif m == "@":
raise ColorParseError("Incomplete color format: '%s' in %s" % (m, match.string))
string = styles[style]
if color:
if color not in colors:
raise ColorParseError(
"Invalid color specifier: '%s' in '%s'" % (color, match.string)
)
string += ";" + str(colors[color])
colored_text = ""
if text:
colored_text = text + self.escape(0)
return self.escape(string) + colored_text
def colorize(string, **kwargs): def colorize(
string: str, color: Optional[bool] = None, enclose: bool = False, zsh: bool = False
) -> str:
"""Replace all color expressions in a string with ANSI control codes. """Replace all color expressions in a string with ANSI control codes.
Args: Args:
string (str): The string to replace string: The string to replace
Returns: Returns:
str: The filtered string The filtered string
Keyword Arguments: Keyword Arguments:
color (bool): If False, output will be plain text without control color: If False, output will be plain text without control codes, for output to
codes, for output to non-console devices. non-console devices (default: automatically choose color or not)
enclose (bool): If True, enclose ansi color sequences with enclose: If True, enclose ansi color sequences with
square brackets to prevent misestimation of terminal width. square brackets to prevent misestimation of terminal width.
zsh (bool): If True, use zsh ansi codes instead of bash ones (for variables like PS1) zsh: If True, use zsh ansi codes instead of bash ones (for variables like PS1)
""" """
color = _color_when_value(kwargs.get("color", get_color_when())) color = color if color is not None else get_color_when()
zsh = kwargs.get("zsh", False)
string = re.sub(color_re, match_to_ansi(color, kwargs.get("enclose")), string, zsh) def match_to_ansi(match):
string = string.replace("}}", "}") """Convert a match object generated by ``COLOR_RE`` into an ansi
return string color code. This can be used as a handler in ``re.sub``.
"""
escaped_at, dot, style, color_code, text = match.groups()
if escaped_at:
return "@"
elif dot:
return _escape(0, color, enclose, zsh)
elif not (style or color_code):
raise ColorParseError(
f"Incomplete color format: '{match.group(0)}' in '{match.string}'"
)
ansi_code = _escape(f"{styles[style]};{colors.get(color_code, '')}", color, enclose, zsh)
if text:
return f"{ansi_code}{text}{_escape(0, color, enclose, zsh)}"
else:
return ansi_code
return COLOR_RE.sub(match_to_ansi, string).replace("}}", "}")
def clen(string): def clen(string):
@ -305,7 +293,7 @@ def cprint(string, stream=None, color=None):
cwrite(string + "\n", stream, color) cwrite(string + "\n", stream, color)
def cescape(string): def cescape(string: str) -> str:
"""Escapes special characters needed for color codes. """Escapes special characters needed for color codes.
Replaces the following symbols with their equivalent literal forms: Replaces the following symbols with their equivalent literal forms:
@ -321,10 +309,7 @@ def cescape(string):
Returns: Returns:
(str): the string with color codes escaped (str): the string with color codes escaped
""" """
string = str(string) return string.replace("@", "@@").replace("}", "}}")
string = string.replace("@", "@@")
string = string.replace("}", "}}")
return string
class ColorStream: class ColorStream:

View File

@ -263,8 +263,8 @@ def _fmt_name_and_default(variant):
return color.colorize(f"@c{{{variant.name}}} @C{{[{_fmt_value(variant.default)}]}}") return color.colorize(f"@c{{{variant.name}}} @C{{[{_fmt_value(variant.default)}]}}")
def _fmt_when(when, indent): def _fmt_when(when: "spack.spec.Spec", indent: int):
return color.colorize(f"{indent * ' '}@B{{when}} {color.cescape(when)}") return color.colorize(f"{indent * ' '}@B{{when}} {color.cescape(str(when))}")
def _fmt_variant_description(variant, width, indent): def _fmt_variant_description(variant, width, indent):
@ -441,7 +441,7 @@ def get_url(version):
return "No URL" return "No URL"
url = get_url(preferred) if pkg.has_code else "" url = get_url(preferred) if pkg.has_code else ""
line = version(" {0}".format(pad(preferred))) + color.cescape(url) line = version(" {0}".format(pad(preferred))) + color.cescape(str(url))
color.cwrite(line) color.cwrite(line)
print() print()
@ -464,7 +464,7 @@ def get_url(version):
continue continue
for v, url in vers: for v, url in vers:
line = version(" {0}".format(pad(v))) + color.cescape(url) line = version(" {0}".format(pad(v))) + color.cescape(str(url))
color.cprint(line) color.cprint(line)
@ -475,10 +475,7 @@ def print_virtuals(pkg, args):
color.cprint(section_title("Virtual Packages: ")) color.cprint(section_title("Virtual Packages: "))
if pkg.provided: if pkg.provided:
for when, specs in reversed(sorted(pkg.provided.items())): for when, specs in reversed(sorted(pkg.provided.items())):
line = " %s provides %s" % ( line = " %s provides %s" % (when.cformat(), ", ".join(s.cformat() for s in specs))
when.colorized(),
", ".join(s.colorized() for s in specs),
)
print(line) print(line)
else: else:
@ -497,7 +494,9 @@ def print_licenses(pkg, args):
pad = padder(pkg.licenses, 4) pad = padder(pkg.licenses, 4)
for when_spec in pkg.licenses: for when_spec in pkg.licenses:
license_identifier = pkg.licenses[when_spec] license_identifier = pkg.licenses[when_spec]
line = license(" {0}".format(pad(license_identifier))) + color.cescape(when_spec) line = license(" {0}".format(pad(license_identifier))) + color.cescape(
str(when_spec)
)
color.cprint(line) color.cprint(line)

View File

@ -83,6 +83,17 @@ def configuration(module_set_name):
) )
_FORMAT_STRING_RE = re.compile(r"({[^}]*})")
def _format_env_var_name(spec, var_name_fmt):
"""Format the variable name, but uppercase any formatted fields."""
fmt_parts = _FORMAT_STRING_RE.split(var_name_fmt)
return "".join(
spec.format(part).upper() if _FORMAT_STRING_RE.match(part) else part for part in fmt_parts
)
def _check_tokens_are_valid(format_string, message): def _check_tokens_are_valid(format_string, message):
"""Checks that the tokens used in the format string are valid in """Checks that the tokens used in the format string are valid in
the context of module file and environment variable naming. the context of module file and environment variable naming.
@ -737,20 +748,12 @@ def environment_modifications(self):
exclude = self.conf.exclude_env_vars exclude = self.conf.exclude_env_vars
# We may have tokens to substitute in environment commands # We may have tokens to substitute in environment commands
# Prepare a suitable transformation dictionary for the names
# of the environment variables. This means turn the valid
# tokens uppercase.
transform = {}
for token in _valid_tokens:
transform[token] = lambda s, string: str.upper(string)
for x in env: for x in env:
# Ensure all the tokens are valid in this context # Ensure all the tokens are valid in this context
msg = "some tokens cannot be expanded in an environment variable name" msg = "some tokens cannot be expanded in an environment variable name"
_check_tokens_are_valid(x.name, message=msg) _check_tokens_are_valid(x.name, message=msg)
# Transform them x.name = _format_env_var_name(self.spec, x.name)
x.name = self.spec.format(x.name, transform=transform)
if self.modification_needs_formatting(x): if self.modification_needs_formatting(x):
try: try:
# Not every command has a value # Not every command has a value

View File

@ -51,7 +51,6 @@
import collections import collections
import collections.abc import collections.abc
import enum import enum
import io
import itertools import itertools
import os import os
import pathlib import pathlib
@ -59,7 +58,7 @@
import re import re
import socket import socket
import warnings import warnings
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Match, Optional, Set, Tuple, Union
import llnl.path import llnl.path
import llnl.string import llnl.string
@ -121,36 +120,44 @@
"SpecDeprecatedError", "SpecDeprecatedError",
] ]
SPEC_FORMAT_RE = re.compile(
r"(?:" # this is one big or, with matches ordered by priority
# OPTION 1: escaped character (needs to be first to catch opening \{)
# Note that an unterminated \ at the end of a string is left untouched
r"(?:\\(.))"
r"|" # or
# OPTION 2: an actual format string
r"{" # non-escaped open brace {
r"([%@/]|arch=)?" # optional sigil (to print sigil in color)
r"(?:\^([^}\.]+)\.)?" # optional ^depname. (to get attr from dependency)
# after the sigil or depname, we can have a hash expression or another attribute
r"(?:" # one of
r"(hash\b)(?:\:(\d+))?" # hash followed by :<optional length>
r"|" # or
r"([^}]*)" # another attribute to format
r")" # end one of
r"(})?" # finish format string with non-escaped close brace }, or missing if not present
r"|"
# OPTION 3: mismatched close brace (option 2 would consume a matched open brace)
r"(})" # brace
r")",
re.IGNORECASE,
)
#: Valid pattern for an identifier in Spack #: Valid pattern for an identifier in Spack
IDENTIFIER_RE = r"\w[\w-]*" IDENTIFIER_RE = r"\w[\w-]*"
# Coloring of specs when using color output. Fields are printed with
# different colors to enhance readability.
# See llnl.util.tty.color for descriptions of the color codes.
COMPILER_COLOR = "@g" #: color for highlighting compilers COMPILER_COLOR = "@g" #: color for highlighting compilers
VERSION_COLOR = "@c" #: color for highlighting versions VERSION_COLOR = "@c" #: color for highlighting versions
ARCHITECTURE_COLOR = "@m" #: color for highlighting architectures ARCHITECTURE_COLOR = "@m" #: color for highlighting architectures
ENABLED_VARIANT_COLOR = "@B" #: color for highlighting enabled variants VARIANT_COLOR = "@B" #: color for highlighting variants
DISABLED_VARIANT_COLOR = "r" #: color for highlighting disabled varaints
DEPENDENCY_COLOR = "@." #: color for highlighting dependencies
HASH_COLOR = "@K" #: color for highlighting package hashes HASH_COLOR = "@K" #: color for highlighting package hashes
#: This map determines the coloring of specs when using color output.
#: We make the fields different colors to enhance readability.
#: See llnl.util.tty.color for descriptions of the color codes.
COLOR_FORMATS = {
"%": COMPILER_COLOR,
"@": VERSION_COLOR,
"=": ARCHITECTURE_COLOR,
"+": ENABLED_VARIANT_COLOR,
"~": DISABLED_VARIANT_COLOR,
"^": DEPENDENCY_COLOR,
"#": HASH_COLOR,
}
#: Regex used for splitting by spec field separators.
#: These need to be escaped to avoid metacharacters in
#: ``COLOR_FORMATS.keys()``.
_SEPARATORS = "[\\%s]" % "\\".join(COLOR_FORMATS.keys())
#: Default format for Spec.format(). This format can be round-tripped, so that: #: Default format for Spec.format(). This format can be round-tripped, so that:
#: Spec(Spec("string").format()) == Spec("string)" #: Spec(Spec("string").format()) == Spec("string)"
DEFAULT_FORMAT = ( DEFAULT_FORMAT = (
@ -193,26 +200,7 @@ class InstallStatus(enum.Enum):
missing = "@r{[-]} " missing = "@r{[-]} "
def colorize_spec(spec): # regexes used in spec formatting
"""Returns a spec colorized according to the colors specified in
COLOR_FORMATS."""
class insert_color:
def __init__(self):
self.last = None
def __call__(self, match):
# ignore compiler versions (color same as compiler)
sep = match.group(0)
if self.last == "%" and sep == "@":
return clr.cescape(sep)
self.last = sep
return "%s%s" % (COLOR_FORMATS[sep], clr.cescape(sep))
return clr.colorize(re.sub(_SEPARATORS, insert_color(), str(spec)) + "@.")
OLD_STYLE_FMT_RE = re.compile(r"\${[A-Z]+}") OLD_STYLE_FMT_RE = re.compile(r"\${[A-Z]+}")
@ -4295,10 +4283,7 @@ def deps():
yield deps yield deps
def colorized(self): def format(self, format_string: str = DEFAULT_FORMAT, color: Optional[bool] = False) -> str:
return colorize_spec(self)
def format(self, format_string=DEFAULT_FORMAT, **kwargs):
r"""Prints out particular pieces of a spec, depending on what is r"""Prints out particular pieces of a spec, depending on what is
in the format string. in the format string.
@ -4361,79 +4346,65 @@ def format(self, format_string=DEFAULT_FORMAT, **kwargs):
literal ``\`` character. literal ``\`` character.
Args: Args:
format_string (str): string containing the format to be expanded format_string: string containing the format to be expanded
color: True for colorized result; False for no color; None for auto color.
Keyword Args:
color (bool): True if returned string is colored
transform (dict): maps full-string formats to a callable \
that accepts a string and returns another one
""" """
ensure_modern_format_string(format_string) ensure_modern_format_string(format_string)
color = kwargs.get("color", False)
transform = kwargs.get("transform", {})
out = io.StringIO() def safe_color(sigil: str, string: str, color_fmt: Optional[str]) -> str:
# avoid colorizing if there is no color or the string is empty
if (color is False) or not color_fmt or not string:
return sigil + string
# escape and add the sigil here to avoid multiple concatenations
if sigil == "@":
sigil = "@@"
return clr.colorize(f"{color_fmt}{sigil}{clr.cescape(string)}@.", color=color)
def write(s, c=None): def format_attribute(match_object: Match) -> str:
f = clr.cescape(s) (esc, sig, dep, hash, hash_len, attribute, close_brace, unmatched_close_brace) = (
if c is not None: match_object.groups()
f = COLOR_FORMATS[c] + f + "@." )
clr.cwrite(f, stream=out, color=color) if esc:
return esc
elif unmatched_close_brace:
raise SpecFormatStringError(f"Unmatched close brace: '{format_string}'")
elif not close_brace:
raise SpecFormatStringError(f"Missing close brace: '{format_string}'")
def write_attribute(spec, attribute, color): current = self if dep is None else self[dep]
attribute = attribute.lower()
sig = "" # Hash attributes can return early.
if attribute.startswith(("@", "%", "/")): # NOTE: we currently treat abstract_hash like an attribute and ignore
# color sigils that are inside braces # any length associated with it. We may want to change that.
sig = attribute[0] if hash:
attribute = attribute[1:] if sig and sig != "/":
elif attribute.startswith("arch="): raise SpecFormatSigilError(sig, "DAG hashes", hash)
sig = " arch=" # include space as separator try:
attribute = attribute[5:] length = int(hash_len) if hash_len else None
except ValueError:
current = spec raise SpecFormatStringError(f"Invalid hash length: '{hash_len}'")
if attribute.startswith("^"): return safe_color(sig or "", current.dag_hash(length), HASH_COLOR)
attribute = attribute[1:]
dep, attribute = attribute.split(".", 1)
current = self[dep]
if attribute == "": if attribute == "":
raise SpecFormatStringError("Format string attributes must be non-empty") raise SpecFormatStringError("Format string attributes must be non-empty")
attribute = attribute.lower()
parts = attribute.split(".") parts = attribute.split(".")
assert parts assert parts
# check that the sigil is valid for the attribute. # check that the sigil is valid for the attribute.
if sig == "@" and parts[-1] not in ("versions", "version"): if not sig:
sig = ""
elif sig == "@" and parts[-1] not in ("versions", "version"):
raise SpecFormatSigilError(sig, "versions", attribute) raise SpecFormatSigilError(sig, "versions", attribute)
elif sig == "%" and attribute not in ("compiler", "compiler.name"): elif sig == "%" and attribute not in ("compiler", "compiler.name"):
raise SpecFormatSigilError(sig, "compilers", attribute) raise SpecFormatSigilError(sig, "compilers", attribute)
elif sig == "/" and not re.match(r"(abstract_)?hash(:\d+)?$", attribute): elif sig == "/" and attribute != "abstract_hash":
raise SpecFormatSigilError(sig, "DAG hashes", attribute) raise SpecFormatSigilError(sig, "DAG hashes", attribute)
elif sig == " arch=" and attribute not in ("architecture", "arch"): elif sig == "arch=":
raise SpecFormatSigilError(sig, "the architecture", attribute) if attribute not in ("architecture", "arch"):
raise SpecFormatSigilError(sig, "the architecture", attribute)
# find the morph function for our attribute sig = " arch=" # include space as separator
morph = transform.get(attribute, lambda s, x: x)
# Special cases for non-spec attributes and hashes.
# These must be the only non-dep component of the format attribute
if attribute == "spack_root":
write(morph(spec, spack.paths.spack_root))
return
elif attribute == "spack_install":
write(morph(spec, spack.store.STORE.layout.root))
return
elif re.match(r"hash(:\d)?", attribute):
col = "#"
if ":" in attribute:
_, length = attribute.split(":")
write(sig + morph(spec, current.dag_hash(int(length))), col)
else:
write(sig + morph(spec, current.dag_hash()), col)
return
# Iterate over components using getattr to get next element # Iterate over components using getattr to get next element
for idx, part in enumerate(parts): for idx, part in enumerate(parts):
@ -4442,7 +4413,7 @@ def write_attribute(spec, attribute, color):
if part.startswith("_"): if part.startswith("_"):
raise SpecFormatStringError("Attempted to format private attribute") raise SpecFormatStringError("Attempted to format private attribute")
else: else:
if isinstance(current, vt.VariantMap): if part == "variants" and isinstance(current, vt.VariantMap):
# subscript instead of getattr for variant names # subscript instead of getattr for variant names
current = current[part] current = current[part]
else: else:
@ -4466,62 +4437,31 @@ def write_attribute(spec, attribute, color):
raise SpecFormatStringError(m) raise SpecFormatStringError(m)
if isinstance(current, vn.VersionList): if isinstance(current, vn.VersionList):
if current == vn.any_version: if current == vn.any_version:
# We don't print empty version lists # don't print empty version lists
return return ""
if callable(current): if callable(current):
raise SpecFormatStringError("Attempted to format callable object") raise SpecFormatStringError("Attempted to format callable object")
if current is None: if current is None:
# We're not printing anything # not printing anything
return return ""
# Set color codes for various attributes # Set color codes for various attributes
col = None color = None
if "variants" in parts: if "variants" in parts:
col = "+" color = VARIANT_COLOR
elif "architecture" in parts: elif "architecture" in parts:
col = "=" color = ARCHITECTURE_COLOR
elif "compiler" in parts or "compiler_flags" in parts: elif "compiler" in parts or "compiler_flags" in parts:
col = "%" color = COMPILER_COLOR
elif "version" in parts or "versions" in parts: elif "version" in parts or "versions" in parts:
col = "@" color = VERSION_COLOR
# Finally, write the output # return colored output
write(sig + morph(spec, str(current)), col) return safe_color(sig, str(current), color)
attribute = "" return SPEC_FORMAT_RE.sub(format_attribute, format_string).strip()
in_attribute = False
escape = False
for c in format_string:
if escape:
out.write(c)
escape = False
elif c == "\\":
escape = True
elif in_attribute:
if c == "}":
write_attribute(self, attribute, color)
attribute = ""
in_attribute = False
else:
attribute += c
else:
if c == "}":
raise SpecFormatStringError(
"Encountered closing } before opening { in %s" % format_string
)
elif c == "{":
in_attribute = True
else:
out.write(c)
if in_attribute:
raise SpecFormatStringError(
"Format string terminated while reading attribute." "Missing terminating }."
)
formatted_spec = out.getvalue()
return formatted_spec.strip()
def cformat(self, *args, **kwargs): def cformat(self, *args, **kwargs):
"""Same as format, but color defaults to auto instead of False.""" """Same as format, but color defaults to auto instead of False."""
@ -4529,6 +4469,16 @@ def cformat(self, *args, **kwargs):
kwargs.setdefault("color", None) kwargs.setdefault("color", None)
return self.format(*args, **kwargs) return self.format(*args, **kwargs)
@property
def spack_root(self):
"""Special field for using ``{spack_root}`` in Spec.format()."""
return spack.paths.spack_root
@property
def spack_install(self):
"""Special field for using ``{spack_install}`` in Spec.format()."""
return spack.store.STORE.layout.root
def format_path( def format_path(
# self, format_string: str, _path_ctor: Optional[pathlib.PurePath] = None # self, format_string: str, _path_ctor: Optional[pathlib.PurePath] = None
self, self,
@ -4554,14 +4504,21 @@ def format_path(
path_ctor = _path_ctor or pathlib.PurePath path_ctor = _path_ctor or pathlib.PurePath
format_string_as_path = path_ctor(format_string) format_string_as_path = path_ctor(format_string)
if format_string_as_path.is_absolute(): if format_string_as_path.is_absolute() or (
# Paths that begin with a single "\" on windows are relative, but we still
# want to preserve the initial "\\" to be consistent with PureWindowsPath.
# Ensure that this '\' is not passed to polite_filename() so it's not converted to '_'
(os.name == "nt" or path_ctor == pathlib.PureWindowsPath)
and format_string_as_path.parts[0] == "\\"
):
output_path_components = [format_string_as_path.parts[0]] output_path_components = [format_string_as_path.parts[0]]
input_path_components = list(format_string_as_path.parts[1:]) input_path_components = list(format_string_as_path.parts[1:])
else: else:
output_path_components = [] output_path_components = []
input_path_components = list(format_string_as_path.parts) input_path_components = list(format_string_as_path.parts)
output_path_components += [ output_path_components += [
fs.polite_filename(self.format(x)) for x in input_path_components fs.polite_filename(self.format(part)) for part in input_path_components
] ]
return str(path_ctor(*output_path_components)) return str(path_ctor(*output_path_components))

View File

@ -390,11 +390,11 @@ def test_built_spec_cache(mirror_dir):
assert any([r["spec"] == s for r in results]) assert any([r["spec"] == s for r in results])
def fake_dag_hash(spec): def fake_dag_hash(spec, length=None):
# Generate an arbitrary hash that is intended to be different than # Generate an arbitrary hash that is intended to be different than
# whatever a Spec reported before (to test actions that trigger when # whatever a Spec reported before (to test actions that trigger when
# the hash changes) # the hash changes)
return "tal4c7h4z0gqmixb1eqa92mjoybxn5l6" return "tal4c7h4z0gqmixb1eqa92mjoybxn5l6"[:length]
@pytest.mark.usefixtures( @pytest.mark.usefixtures(

View File

@ -1276,7 +1276,7 @@ def test_user_config_path_is_default_when_env_var_is_empty(working_env):
def test_default_install_tree(monkeypatch, default_config): def test_default_install_tree(monkeypatch, default_config):
s = spack.spec.Spec("nonexistent@x.y.z %none@a.b.c arch=foo-bar-baz") s = spack.spec.Spec("nonexistent@x.y.z %none@a.b.c arch=foo-bar-baz")
monkeypatch.setattr(s, "dag_hash", lambda: "abc123") monkeypatch.setattr(s, "dag_hash", lambda length: "abc123")
_, _, projections = spack.store.parse_install_tree(spack.config.get("config")) _, _, projections = spack.store.parse_install_tree(spack.config.get("config"))
assert s.format(projections["all"]) == "foo-bar-baz/none-a.b.c/nonexistent-x.y.z-abc123" assert s.format(projections["all"]) == "foo-bar-baz/none-a.b.c/nonexistent-x.y.z-abc123"

View File

@ -146,9 +146,6 @@ def test_autoload_all(self, modulefile_content, module_configuration):
assert len([x for x in content if "depends_on(" in x]) == 5 assert len([x for x in content if "depends_on(" in x]) == 5
@pytest.mark.skipif(
str(archspec.cpu.host().family) != "x86_64", reason="test data is specific for x86_64"
)
def test_alter_environment(self, modulefile_content, module_configuration): def test_alter_environment(self, modulefile_content, module_configuration):
"""Tests modifications to run-time environment.""" """Tests modifications to run-time environment."""

View File

@ -114,9 +114,6 @@ def test_prerequisites_all(
assert len([x for x in content if "prereq" in x]) == 5 assert len([x for x in content if "prereq" in x]) == 5
@pytest.mark.skipif(
str(archspec.cpu.host().family) != "x86_64", reason="test data is specific for x86_64"
)
def test_alter_environment(self, modulefile_content, module_configuration): def test_alter_environment(self, modulefile_content, module_configuration):
"""Tests modifications to run-time environment.""" """Tests modifications to run-time environment."""

View File

@ -703,22 +703,25 @@ def check_prop(check_spec, fmt_str, prop, getter):
actual = spec.format(named_str) actual = spec.format(named_str)
assert expected == actual assert expected == actual
def test_spec_formatting_escapes(self, default_mock_concretization): @pytest.mark.parametrize(
spec = default_mock_concretization("multivalue-variant cflags=-O2") "fmt_str",
[
sigil_mismatches = [
"{@name}", "{@name}",
"{@version.concrete}", "{@version.concrete}",
"{%compiler.version}", "{%compiler.version}",
"{/hashd}", "{/hashd}",
"{arch=architecture.os}", "{arch=architecture.os}",
] ],
)
def test_spec_formatting_sigil_mismatches(self, default_mock_concretization, fmt_str):
spec = default_mock_concretization("multivalue-variant cflags=-O2")
for fmt_str in sigil_mismatches: with pytest.raises(SpecFormatSigilError):
with pytest.raises(SpecFormatSigilError): spec.format(fmt_str)
spec.format(fmt_str)
bad_formats = [ @pytest.mark.parametrize(
"fmt_str",
[
r"{}", r"{}",
r"name}", r"name}",
r"\{name}", r"\{name}",
@ -728,11 +731,12 @@ def test_spec_formatting_escapes(self, default_mock_concretization):
r"{dag_hash}", r"{dag_hash}",
r"{foo}", r"{foo}",
r"{+variants.debug}", r"{+variants.debug}",
] ],
)
for fmt_str in bad_formats: def test_spec_formatting_bad_formats(self, default_mock_concretization, fmt_str):
with pytest.raises(SpecFormatStringError): spec = default_mock_concretization("multivalue-variant cflags=-O2")
spec.format(fmt_str) with pytest.raises(SpecFormatStringError):
spec.format(fmt_str)
def test_combination_of_wildcard_or_none(self): def test_combination_of_wildcard_or_none(self):
# Test that using 'none' and another value raises # Test that using 'none' and another value raises
@ -1138,12 +1142,12 @@ def _check_spec_format_path(spec_str, format_str, expected, path_ctor=None):
r"\\hostname\sharename\{name}\{version}", r"\\hostname\sharename\{name}\{version}",
r"\\hostname\sharename\git-test\git.foo_bar", r"\\hostname\sharename\git-test\git.foo_bar",
), ),
# Windows doesn't attribute any significance to a leading # leading '/' is preserved on windows but converted to '\'
# "/" so it is discarded # note that it's still not "absolute" -- absolute windows paths start with a drive.
( (
"git-test@git.foo/bar", "git-test@git.foo/bar",
r"/installroot/{name}/{version}", r"/installroot/{name}/{version}",
r"installroot\git-test\git.foo_bar", r"\installroot\git-test\git.foo_bar",
), ),
], ],
) )

View File

@ -638,6 +638,9 @@ def copy(self):
return clone return clone
def __str__(self): def __str__(self):
if not self:
return ""
# print keys in order # print keys in order
sorted_keys = sorted(self.keys()) sorted_keys = sorted(self.keys())

View File

@ -146,13 +146,11 @@ def from_string(string: str):
@staticmethod @staticmethod
def typemin(): def typemin():
return StandardVersion("", ((), (ALPHA,)), ("",)) return _STANDARD_VERSION_TYPEMIN
@staticmethod @staticmethod
def typemax(): def typemax():
return StandardVersion( return _STANDARD_VERSION_TYPEMAX
"infinity", ((VersionStrComponent(len(infinity_versions)),), (FINAL,)), ("",)
)
def __bool__(self): def __bool__(self):
return True return True
@ -390,6 +388,13 @@ def up_to(self, index):
return self[:index] return self[:index]
_STANDARD_VERSION_TYPEMIN = StandardVersion("", ((), (ALPHA,)), ("",))
_STANDARD_VERSION_TYPEMAX = StandardVersion(
"infinity", ((VersionStrComponent(len(infinity_versions)),), (FINAL,)), ("",)
)
class GitVersion(ConcreteVersion): class GitVersion(ConcreteVersion):
"""Class to represent versions interpreted from git refs. """Class to represent versions interpreted from git refs.
@ -1019,6 +1024,9 @@ def __hash__(self):
return hash(tuple(self.versions)) return hash(tuple(self.versions))
def __str__(self): def __str__(self):
if not self.versions:
return ""
return ",".join( return ",".join(
f"={v}" if isinstance(v, StandardVersion) else str(v) for v in self.versions f"={v}" if isinstance(v, StandardVersion) else str(v) for v in self.versions
) )
@ -1127,7 +1135,9 @@ def _prev_version(v: StandardVersion) -> StandardVersion:
components[1::2] = separators[: len(release)] components[1::2] = separators[: len(release)]
if prerelease_type != FINAL: if prerelease_type != FINAL:
components.extend((PRERELEASE_TO_STRING[prerelease_type], *prerelease[1:])) components.extend((PRERELEASE_TO_STRING[prerelease_type], *prerelease[1:]))
return StandardVersion("".join(str(c) for c in components), (release, prerelease), separators)
# this is only used for comparison functions, so don't bother making a string
return StandardVersion(None, (release, prerelease), separators)
def Version(string: Union[str, int]) -> Union[GitVersion, StandardVersion]: def Version(string: Union[str, int]) -> Union[GitVersion, StandardVersion]: