Compare commits


3 Commits

Author SHA1 Message Date
Todd Gamblin
690fad1182 WIP 2025-03-25 22:34:39 -07:00
Todd Gamblin
13446994ab WIP 2025-03-25 22:34:39 -07:00
Todd Gamblin
327462e8e2 info: generify when-grouping code
We want to show dependencies grouped by conditions, as we already do
with variants. This takes the first step and generifies the variant
display code.
2025-03-25 22:34:39 -07:00
17 changed files with 485 additions and 414 deletions

View File

@@ -4,6 +4,7 @@
import json
import os
import re
import shutil
import sys
from typing import Dict
@@ -25,10 +26,12 @@
import spack.hash_types as ht
import spack.mirrors.mirror
import spack.package_base
import spack.paths
import spack.repo
import spack.spec
import spack.stage
import spack.util.executable
import spack.util.git
import spack.util.gpg as gpg_util
import spack.util.timer as timer
import spack.util.url as url_util
@@ -42,6 +45,7 @@
SPACK_COMMAND = "spack"
INSTALL_FAIL_CODE = 1
FAILED_CREATE_BUILDCACHE_CODE = 100
BUILTIN = re.compile(r"var\/spack\/repos\/builtin\/packages\/([^\/]+)\/package\.py")
def deindent(desc):
@@ -779,15 +783,18 @@ def ci_verify_versions(args):
then parses the git diff between the two to determine which packages
have been modified and verifies the new checksums inside of them.
"""
# Get a list of all packages that have been changed or added
# between from_ref and to_ref
pkgs = spack.repo.get_all_package_diffs("AC", args.from_ref, args.to_ref)
with fs.working_dir(spack.paths.prefix):
# We use HEAD^1 explicitly on the merge commit created by
# GitHub Actions. However, HEAD~1 is a safer default for the helper function.
files = spack.util.git.get_modified_files(from_ref=args.from_ref, to_ref=args.to_ref)
# Get a list of package names from the modified files.
pkgs = [(m.group(1), p) for p in files for m in [BUILTIN.search(p)] if m]
failed_version = False
for pkg_name in pkgs:
for pkg_name, path in pkgs:
spec = spack.spec.Spec(pkg_name)
pkg = spack.repo.PATH.get_pkg_class(spec.name)(spec)
path = spack.repo.PATH.package_path(pkg_name)
# Skip checking manual download packages and trust the maintainers
if pkg.manual_download:
@@ -811,7 +818,7 @@ def ci_verify_versions(args):
# TODO: enforce that every version has a commit or a sha256 defined if not
# an infinite version (there are a lot of packages where this doesn't work yet).
with fs.working_dir(os.path.dirname(path)):
with fs.working_dir(spack.paths.prefix):
added_checksums = spack_ci.get_added_versions(
checksums_version_dict, path, from_ref=args.from_ref, to_ref=args.to_ref
)

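As a quick illustration (an editor's sketch, not part of the diff): the BUILTIN regex above turns the list of modified files into (package name, path) pairs, silently dropping anything outside the builtin package tree. The file paths below are made-up examples.

import re

BUILTIN = re.compile(r"var\/spack\/repos\/builtin\/packages\/([^\/]+)\/package\.py")

files = [
    "var/spack/repos/builtin/packages/diff-test/package.py",
    "lib/spack/spack/cmd/ci.py",
]
# same comprehension used in ci_verify_versions above
pkgs = [(m.group(1), p) for p in files for m in [BUILTIN.search(p)] if m]
# pkgs == [("diff-test", "var/spack/repos/builtin/packages/diff-test/package.py")]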
View File

@@ -5,6 +5,7 @@
import sys
import textwrap
from itertools import zip_longest
from typing import Callable, Dict, TypeVar
import llnl.util.tty as tty
import llnl.util.tty.color as color
@@ -14,11 +15,12 @@
import spack.deptypes as dt
import spack.fetch_strategy as fs
import spack.install_test
import spack.package_base
import spack.repo
import spack.spec
import spack.variant
from spack.cmd.common import arguments
from spack.package_base import preferred_version
from spack.util.typing import SupportsRichComparison
description = "get detailed information on a particular package"
section = "basic"
@@ -28,6 +30,44 @@
plain_format = "@."
class Formatter:
"""Generic formatter for elements displayed by `spack info`.
Elements have four parts: name, values, when condition, and description. They can
be formatted two ways (shown here for variants)::
Grouped by when (default)::
when +cuda
cuda_arch [none] none, 10, 100, 100a, 101,
101a, 11, 12, 120, 120a, 13
CUDA architecture
Or, by name (each name has a when nested under it)::
cuda_arch [none] none, 10, 100, 100a, 101,
101a, 11, 12, 120, 120a, 13
when +cuda
CUDA architecture
The values and description will be wrapped if needed. The name (and any additional info)
will not (so they should be kept short).
Subclasses are responsible for generating colorized text, but not wrapping,
indentation, or other formatting, for the name, values, and description.
"""
def format_name(self, element) -> str:
return ""
def format_values(self, element) -> str:
return ""
def format_description(self, element) -> str:
return ""
def padder(str_list, extra=0):
"""Return a function to pad elements of a list."""
length = max(len(str(s)) for s in str_list) + extra
@@ -140,17 +180,19 @@ def lines(self):
yield " " + self.fmt % t
class DependencyFormatter(Formatter):
def format_name(self, dep) -> str:
return str(dep.spec)
def format_values(self, dep) -> str:
return str(dt.flag_to_tuple(dep.depflag))
def print_dependencies(pkg, args):
"""output build, link, and run package dependencies"""
for deptype in ("build", "link", "run"):
color.cprint("")
color.cprint(section_title("%s Dependencies:" % deptype.capitalize()))
deps = sorted(pkg.dependencies_of_type(dt.flag_from_string(deptype)))
if deps:
colify(deps, indent=4)
else:
color.cprint(" None")
print_fn = print_by_name if args.variants_by_name else print_grouped_by_when
print_fn("Dependencies", pkg.dependencies, DependencyFormatter())
def print_detectable(pkg, args):
@@ -263,66 +305,70 @@ def print_tests(pkg, args):
color.cprint(" None")
def _fmt_value(v):
if v is None or isinstance(v, bool):
return str(v).lower()
else:
return str(v)
def _fmt_name_and_default(variant):
"""Print colorized name [default] for a variant."""
return color.colorize(f"@c{{{variant.name}}} @C{{[{_fmt_value(variant.default)}]}}")
def _fmt_when(when: "spack.spec.Spec", indent: int):
return color.colorize(f"{indent * ' '}@B{{when}} {color.cescape(str(when))}")
def _fmt_variant_description(variant, width, indent):
"""Format a variant's description, preserving explicit line breaks."""
return "\n".join(
textwrap.fill(
line, width=width, initial_indent=indent * " ", subsequent_indent=indent * " "
def _fmt_variant_value(v):
return str(v).lower() if v is None or isinstance(v, bool) else str(v)
class VariantFormatter(Formatter):
def format_name(self, variant) -> str:
return color.colorize(
f"@c{{{variant.name}}} @C{{[{_fmt_variant_value(variant.default)}]}}"
)
for line in variant.description.split("\n")
)
def format_values(self, variant) -> str:
values = variant.values
if not isinstance(variant.values, (tuple, list, spack.variant.DisjointSetsOfValues)):
values = [variant.values]
# put 'none' first, sort the rest by value
sorted_values = sorted(values, key=lambda v: (v != "none", v))
return color.colorize(f"@c{{{', '.join(_fmt_variant_value(v) for v in sorted_values)}}}")
def format_description(self, variant) -> str:
return variant.description
def _fmt_variant(variant, max_name_default_len, indent, when=None, out=None):
def _fmt_definition(
name_field, values_field, description, max_name_len, indent, when=None, out=None
):
"""Format a definition entry in `spack info` output.
Arguments:
name_field: name and optional info, e.g. a default; should be short.
values_field: possible values for the entry; wrapped if long.
description: description of the entry (wrapped if overly long)
max_name_len: length of the longest name field; used to align the values column
indent: size of leading indent for entry
when: optional when condition
out: stream to print to
"""
out = out or sys.stdout
_, cols = tty.terminal_size()
name_and_default = _fmt_name_and_default(variant)
name_default_len = color.clen(name_and_default)
name_len = color.clen(name_field)
values = variant.values
if not isinstance(variant.values, (tuple, list, spack.variant.DisjointSetsOfValues)):
values = [variant.values]
pad = 4 # min padding between name and values
value_indent = (indent + max_name_len + pad) * " " # left edge of values
# put 'none' first, sort the rest by value
sorted_values = sorted(values, key=lambda v: (v != "none", v))
pad = 4 # min padding between 'name [default]' and values
value_indent = (indent + max_name_default_len + pad) * " " # left edge of values
# This preserves any formatting (i.e., newlines) from how the description was
# written in package.py, but still wraps long lines for small terminals.
# This allows some packages to provide detailed help on their variants (see, e.g., gasnet).
formatted_values = "\n".join(
textwrap.wrap(
f"{', '.join(_fmt_value(v) for v in sorted_values)}",
width=cols - 2,
initial_indent=value_indent,
subsequent_indent=value_indent,
if values_field:
formatted_values = "\n".join(
textwrap.wrap(
values_field,
width=cols - 2,
initial_indent=value_indent,
subsequent_indent=value_indent,
)
)
)
formatted_values = formatted_values[indent + name_default_len + pad :]
# trim initial indentation
formatted_values = formatted_values[indent + name_len + pad :]
# name [default] value1, value2, value3, ...
padding = pad * " "
color.cprint(f"{indent * ' '}{name_and_default}{padding}@c{{{formatted_values}}}", stream=out)
# name [default] value1, value2, value3, ...
out.write(f"{indent * ' '}{name_field}{pad * ' '}{formatted_values}\n")
# when <spec>
description_indent = indent + 4
@@ -330,38 +376,65 @@ def _fmt_variant(variant, max_name_default_len, indent, when=None, out=None):
out.write(_fmt_when(when, description_indent - 2))
out.write("\n")
# description, preserving explicit line breaks from the way it's written in the package file
out.write(_fmt_variant_description(variant, cols - 2, description_indent))
out.write("\n")
# description, preserving explicit line breaks from the way it's written in the
# package file, but still wrapping long lines for small terminals. This allows
# packages to provide detailed help in their descriptions (see, e.g., gasnet's variants).
if description:
formatted_description = "\n".join(
textwrap.fill(
line,
width=cols - 2,
initial_indent=description_indent * " ",
subsequent_indent=description_indent * " ",
)
for line in description.split("\n")
)
out.write(formatted_description)
out.write("\n")
def _print_variants_header(pkg):
"""output variants"""
K = TypeVar("K", bound=SupportsRichComparison)
V = TypeVar("V")
if not pkg.variants:
print(" None")
return
def print_header(header: str, when_indexed_dictionary: Dict, formatter: Formatter):
color.cprint("")
color.cprint(section_title("Variants:"))
color.cprint(section_title(f"{header}:"))
# Calculate the max length of the "name [default]" part of the variant display
# This lets us know where to print variant values.
max_name_default_len = max(
color.clen(_fmt_name_and_default(variant))
for name in pkg.variant_names()
for _, variant in pkg.variant_definitions(name)
if not when_indexed_dictionary:
print(" None")
def max_name_length(when_indexed_dictionary: Dict, formatter: Formatter) -> int:
# Calculate the max length of the first field of the definition. Lets us know how
# much to pad other fields on the first line.
return max(
color.clen(formatter.format_name(definition))
for subkey in spack.package_base._subkeys(when_indexed_dictionary)
for _, definition in spack.package_base._definitions(when_indexed_dictionary, subkey)
)
return max_name_default_len
def print_grouped_by_when(header: str, when_indexed_dictionary: Dict, formatter: Formatter):
"""Generic function to print metadata grouped by when conditions."""
def print_variants_grouped_by_when(pkg):
max_name_default_len = _print_variants_header(pkg)
print_header(header, when_indexed_dictionary, formatter)
if not when_indexed_dictionary:
return
max_name_len = max_name_length(when_indexed_dictionary, formatter)
# Calculate the max length of the first field of the definition. Lets us know how
# much to pad other fields on the first line.
max_name_len = max(
color.clen(formatter.format_name(definition))
for subkey in spack.package_base._subkeys(when_indexed_dictionary)
for _, definition in spack.package_base._definitions(when_indexed_dictionary, subkey)
)
indent = 4
for when, variants_by_name in pkg.variant_items():
padded_values = max_name_default_len + 4
for when, by_name in when_indexed_dictionary.items():
padded_values = max_name_len + 4
start_indent = indent
if when != spack.spec.Spec():
@@ -373,27 +446,46 @@ def print_variants_grouped_by_when(pkg):
padded_values -= 2
start_indent += 2
for name, variant in sorted(variants_by_name.items()):
_fmt_variant(variant, padded_values, start_indent, None, out=sys.stdout)
for subkey, definition in sorted(by_name.items()):
_fmt_definition(
formatter.format_name(definition),
formatter.format_values(definition),
formatter.format_description(definition),
max_name_len,
start_indent,
when=None,
out=sys.stdout,
)
def print_variants_by_name(pkg):
max_name_default_len = _print_variants_header(pkg)
max_name_default_len += 4
def print_by_name(header: str, when_indexed_dictionary: Dict, formatter: Formatter):
print_header(header, when_indexed_dictionary, formatter)
if not when_indexed_dictionary:
return
max_name_len = max_name_length(when_indexed_dictionary, formatter)
max_name_len += 4
indent = 4
for name in pkg.variant_names():
for when, variant in pkg.variant_definitions(name):
_fmt_variant(variant, max_name_default_len, indent, when, out=sys.stdout)
for subkey in spack.package_base._subkeys(when_indexed_dictionary):
for when, definition in spack.package_base._definitions(when_indexed_dictionary, subkey):
_fmt_definition(
formatter.format_name(definition),
formatter.format_values(definition),
formatter.format_description(definition),
max_name_len,
indent,
when=when,
out=sys.stdout,
)
sys.stdout.write("\n")
def print_variants(pkg, args):
"""output variants"""
if args.variants_by_name:
print_variants_by_name(pkg)
else:
print_variants_grouped_by_when(pkg)
print_fn = print_by_name if args.variants_by_name else print_grouped_by_when
print_fn("Variants", pkg.variants, VariantFormatter())
def print_versions(pkg, args):
@@ -413,7 +505,7 @@ def print_versions(pkg, args):
else:
pad = padder(pkg.versions, 4)
preferred = preferred_version(pkg)
preferred = spack.package_base.preferred_version(pkg)
def get_url(version):
try:

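To see how the generic pieces above are meant to be reused (an illustrative sketch, not part of the diff): any when-indexed dictionary shaped like pkg.variants or pkg.dependencies can be rendered by pairing it with a small Formatter subclass. FooFormatter and pkg.foos below are hypothetical names.

class FooFormatter(Formatter):
    # only these three hooks need to be provided; wrapping and indentation
    # are handled by _fmt_definition and the print_* helpers
    def format_name(self, foo) -> str:
        return foo.name

    def format_values(self, foo) -> str:
        return ", ".join(str(v) for v in foo.values)

    def format_description(self, foo) -> str:
        return foo.description


# pkg.foos is assumed to map when-specs to {name: definition}, the same
# layout this diff relies on for pkg.variants and pkg.dependencies
print_fn = print_by_name if args.variants_by_name else print_grouped_by_when
print_fn("Foos", pkg.foos, FooFormatter())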
View File

@@ -22,22 +22,9 @@
import textwrap
import time
import traceback
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing_extensions import Literal, final
from typing_extensions import Literal
import llnl.util.filesystem as fsys
import llnl.util.tty as tty
@@ -457,7 +444,7 @@ def _precedence(obj) -> int:
"""Get either a 'precedence' attribute or item from an object."""
precedence = getattr(obj, "precedence", None)
if precedence is None:
raise KeyError(f"Couldn't get precedence from {type(obj)}")
return 0 # raise KeyError(f"Couldn't get precedence from {type(obj)}")
return precedence
@@ -1394,75 +1381,6 @@ def command(self) -> spack.util.executable.Executable:
return spack.util.executable.Executable(path)
raise RuntimeError(f"Unable to locate {self.spec.name} command in {self.home.bin}")
def find_headers(
self, *, features: Sequence[str] = (), virtual: Optional[str] = None
) -> fsys.HeaderList:
"""Return the header list for this package based on the query. This method can be
overridden by individual packages to return package specific headers.
Args:
features: query argument to filter or extend the header list.
virtual: when set, return headers relevant for the virtual provided by this package.
Raises:
spack.error.NoHeadersError: if there was an error locating the headers.
"""
spec = self.spec
home = self.home
headers = fsys.find_headers("*", root=home.include, recursive=True)
if headers:
return headers
raise spack.error.NoHeadersError(f"Unable to locate {spec.name} headers in {home}")
def find_libs(
self, *, features: Sequence[str] = (), virtual: Optional[str] = None
) -> fsys.LibraryList:
"""Return the library list for this package based on the query. This method can be
overridden by individual packages to return package specific libraries.
Args:
features: query argument to filter or extend the library list.
virtual: when set, return libraries relevant for the virtual provided by this package.
Raises:
spack.error.NoLibrariesError: if there was an error locating the libraries.
"""
spec = self.spec
home = self.home
name = self.spec.name.replace("-", "?")
# Avoid double 'lib' for packages whose names already start with lib
if not name.startswith("lib") and not spec.satisfies("platform=windows"):
name = "lib" + name
# If '+shared' search only for shared library; if '~shared' search only for
# static library; otherwise, first search for shared and then for static.
search_shared = (
[True] if ("+shared" in spec) else ([False] if ("~shared" in spec) else [True, False])
)
for shared in search_shared:
# Since we are searching for link libraries, on Windows search only for
# ".Lib" extensions by default as those represent import libraries for implicit links.
libs = fsys.find_libraries(name, home, shared=shared, recursive=True, runtime=False)
if libs:
return libs
raise spack.error.NoLibrariesError(
f"Unable to recursively locate {spec.name} libraries in {home}"
)
@final
def query_headers(self, name: str, *, features: Sequence[str] = ()) -> fsys.HeaderList:
"""Returns the header list for a dependency ``name``."""
spec, is_virtual = self.spec._get_dependency_by_name(name)
return spec.package.find_headers(features=features, virtual=name if is_virtual else None)
@final
def query_libs(self, name: str, *, features: Sequence[str] = ()) -> fsys.LibraryList:
"""Returns the library list for a dependency ``name``."""
spec, is_virtual = self.spec._get_dependency_by_name(name)
return spec.package.find_libs(features=features, virtual=name if is_virtual else None)
def url_version(self, version):
"""
Given a version, this returns a string that should be substituted

View File

@@ -3005,10 +3005,6 @@ def setup(
# Fail if we already know an unreachable node is requested
for spec in specs:
# concrete roots don't need their dependencies verified
if spec.concrete:
continue
missing_deps = [
str(d)
for d in spec.traverse()

View File

@@ -663,9 +663,11 @@ def versions(self):
def display_str(self):
"""Equivalent to {compiler.name}{@compiler.version} for Specs, without extra
@= for readability."""
if self.versions != vn.any_version:
return self.spec.format("{name}{@version}")
return self.spec.format("{name}")
if self.spec.concrete:
return f"{self.name}@{self.version}"
elif self.versions != vn.any_version:
return f"{self.name}@{self.versions}"
return self.name
def __lt__(self, other):
if not isinstance(other, CompilerSpec):
@@ -1070,26 +1072,123 @@ def clear(self):
self.edges.clear()
def _headers_default_handler(spec: "Spec"):
"""Default handler when looking for the 'headers' attribute.
Tries to search for ``*.h`` files recursively starting from
``spec.package.home.include``.
Parameters:
spec: spec that is being queried
Returns:
HeaderList: The headers in ``prefix.include``
Raises:
NoHeadersError: If no headers are found
"""
home = getattr(spec.package, "home")
headers = fs.find_headers("*", root=home.include, recursive=True)
if headers:
return headers
raise spack.error.NoHeadersError(f"Unable to locate {spec.name} headers in {home}")
def _libs_default_handler(spec: "Spec"):
"""Default handler when looking for the 'libs' attribute.
Tries to search for ``lib{spec.name}`` recursively starting from
``spec.package.home``. If ``spec.name`` starts with ``lib``, searches for
``{spec.name}`` instead.
Parameters:
spec: spec that is being queried
Returns:
LibraryList: The libraries found
Raises:
NoLibrariesError: If no libraries are found
"""
# Variable 'name' is passed to function 'find_libraries', which supports
# glob characters. For example, we have a package with a name 'abc-abc'.
# Now, we don't know if the original name of the package is 'abc_abc'
# (and it generates a library 'libabc_abc.so') or 'abc-abc' (and it
# generates a library 'libabc-abc.so'). So, we tell the function
# 'find_libraries' to give us anything that matches 'libabc?abc' and it
# gives us either 'libabc-abc.so' or 'libabc_abc.so' (or an error)
# depending on which one exists (there is a possibility, of course, to
# get something like 'libabcXabc.so', but for now we consider this
# unlikely).
name = spec.name.replace("-", "?")
home = getattr(spec.package, "home")
# Avoid double 'lib' for packages whose names already start with lib
if not name.startswith("lib") and not spec.satisfies("platform=windows"):
name = "lib" + name
# If '+shared' search only for shared library; if '~shared' search only for
# static library; otherwise, first search for shared and then for static.
search_shared = (
[True] if ("+shared" in spec) else ([False] if ("~shared" in spec) else [True, False])
)
for shared in search_shared:
# Since we are searching for link libraries, on Windows search only for
# ".Lib" extensions by default as those represent import libraries for implicit links.
libs = fs.find_libraries(name, home, shared=shared, recursive=True, runtime=False)
if libs:
return libs
raise spack.error.NoLibrariesError(
f"Unable to recursively locate {spec.name} libraries in {home}"
)
class ForwardQueryToPackage:
"""Descriptor used to forward queries from Spec to Package"""
def __init__(self, attribute_name: str, _indirect: bool = False) -> None:
def __init__(
self,
attribute_name: str,
default_handler: Optional[Callable[["Spec"], Any]] = None,
_indirect: bool = False,
) -> None:
"""Create a new descriptor.
Parameters:
attribute_name: name of the attribute to be searched for in the Package instance
default_handler: default function to be called if the attribute was not found in the
Package instance
_indirect: temporarily added to redirect a query to another package.
"""
self.attribute_name = attribute_name
self.default = default_handler
self.indirect = _indirect
def __get__(self, instance: "SpecBuildInterface", cls):
"""Retrieves the property from Package using a well defined chain of responsibility.
"""Retrieves the property from Package using a well defined chain
of responsibility.
The call order is:
The order of call is:
1. `pkg.{virtual_name}_{attribute_name}` if the query is for a virtual package
2. `pkg.{attribute_name}` otherwise
1. if the query was through the name of a virtual package try to
search for the attribute `{virtual_name}_{attribute_name}`
in Package
2. try to search for attribute `{attribute_name}` in Package
3. try to call the default handler
The first call that produces a value will stop the chain.
If no call can handle the request then AttributeError is raised with a
message indicating that no relevant attribute exists.
If a call returns None, an AttributeError is raised with a message
indicating a query failure, e.g. that library files were not found in a
'libs' query.
"""
# TODO: this indirection exists solely for `spec["python"].command` to actually return
# spec["python-venv"].command. It should be removed when `python` is a virtual.
@@ -1105,36 +1204,61 @@ def __get__(self, instance: "SpecBuildInterface", cls):
_ = instance.wrapped_obj[instance.wrapped_obj.name] # NOQA: ignore=F841
query = instance.last_query
# First try the deprecated attributes (e.g. `<virtual>_libs` and `libs`)
callbacks_chain = []
# First in the chain : specialized attribute for virtual packages
if query.isvirtual:
deprecated_attrs = [f"{query.name}_{self.attribute_name}", self.attribute_name]
else:
deprecated_attrs = [self.attribute_name]
specialized_name = "{0}_{1}".format(query.name, self.attribute_name)
callbacks_chain.append(lambda: getattr(pkg, specialized_name))
# Try to get the generic method from Package
callbacks_chain.append(lambda: getattr(pkg, self.attribute_name))
# Final resort : default callback
if self.default is not None:
_default = self.default # make mypy happy
callbacks_chain.append(lambda: _default(instance.wrapped_obj))
for attr in deprecated_attrs:
if not hasattr(pkg, attr):
continue
value = getattr(pkg, attr)
# Deprecated properties can return None to indicate the query failed.
if value is None:
raise AttributeError(
f"Query of package '{pkg.name}' for '{self.attribute_name}' failed\n"
f"\tprefix : {instance.prefix}\n" # type: ignore[attr-defined]
f"\tspec : {instance}\n"
f"\tqueried as : {query.name}\n"
f"\textra parameters : {query.extra_parameters}"
)
return value
# Then try the new functions (e.g. `find_libs`).
features = query.extra_parameters
virtual = query.name if query.isvirtual else None
if self.attribute_name == "libs":
return pkg.find_libs(features=features, virtual=virtual)
elif self.attribute_name == "headers":
return pkg.find_headers(features=features, virtual=virtual)
raise AttributeError(f"Package {pkg.name} has no attribute {self.attribute_name}")
# Trigger the callbacks in order; the first one producing a
# value wins
value = None
message = None
for f in callbacks_chain:
try:
value = f()
# A callback can return None to trigger an error indicating
# that the query failed.
if value is None:
msg = "Query of package '{name}' for '{attrib}' failed\n"
msg += "\tprefix : {spec.prefix}\n"
msg += "\tspec : {spec}\n"
msg += "\tqueried as : {query.name}\n"
msg += "\textra parameters : {query.extra_parameters}"
message = msg.format(
name=pkg.name,
attrib=self.attribute_name,
spec=instance,
query=instance.last_query,
)
else:
return value
break
except AttributeError:
pass
# value is 'None'
if message is not None:
# Here we can use another type of exception. If we do that, the
# unit test 'test_getitem_exceptional_paths' in the file
# lib/spack/spack/test/spec_dag.py will need to be updated to match
# the type.
raise AttributeError(message)
# 'None' value at this point means that there are no appropriate
# properties defined and no default handler, or that all callbacks
# raised AttributeError. In this case, we raise AttributeError with an
# appropriate message.
fmt = "'{name}' package has no relevant attribute '{query}'\n"
fmt += "\tspec : '{spec}'\n"
fmt += "\tqueried as : '{spec.last_query.name}'\n"
fmt += "\textra parameters : '{spec.last_query.extra_parameters}'\n"
message = fmt.format(name=pkg.name, query=self.attribute_name, spec=instance)
raise AttributeError(message)
def __set__(self, instance, value):
cls_name = type(instance).__name__
@@ -1148,10 +1272,10 @@ def __set__(self, instance, value):
class SpecBuildInterface(lang.ObjectWrapper):
# home is available in the base Package so no default is needed
home = ForwardQueryToPackage("home")
headers = ForwardQueryToPackage("headers")
libs = ForwardQueryToPackage("libs")
command = ForwardQueryToPackage("command", _indirect=True)
home = ForwardQueryToPackage("home", default_handler=None)
headers = ForwardQueryToPackage("headers", default_handler=_headers_default_handler)
libs = ForwardQueryToPackage("libs", default_handler=_libs_default_handler)
command = ForwardQueryToPackage("command", default_handler=None, _indirect=True)
def __init__(
self,
@@ -3520,21 +3644,6 @@ def version(self):
raise spack.error.SpecError("Spec version is not concrete: " + str(self))
return self.versions[0]
def _get_dependency_by_name(self, name: str) -> Tuple["Spec", bool]:
"""Get a dependency by package name or virtual. Returns a tuple with the matching spec
and a boolean indicating if the spec is a virtual dependency. Raises a KeyError if the
dependency is not found."""
# Consider all direct dependencies and transitive runtime dependencies
order = itertools.chain(
self.edges_to_dependencies(depflag=dt.BUILD | dt.TEST),
self.traverse_edges(deptype=dt.LINK | dt.RUN, order="breadth", cover="edges"),
)
edge = next((e for e in order if e.spec.name == name or name in e.virtuals), None)
if edge is None:
raise KeyError(f"No spec with name {name} in {self}")
return edge.spec, name in edge.virtuals
def __getitem__(self, name: str):
"""Get a dependency from the spec by its name. This call implicitly
sets a query state in the package being retrieved. The behavior of
@@ -3555,14 +3664,23 @@ def __getitem__(self, name: str):
csv = query_parameters.pop().strip()
query_parameters = re.split(r"\s*,\s*", csv)
spec, is_virtual = self._get_dependency_by_name(name)
# Consider all direct dependencies and transitive runtime dependencies
order = itertools.chain(
self.edges_to_dependencies(depflag=dt.BUILD | dt.TEST),
self.traverse_edges(deptype=dt.LINK | dt.RUN, order="breadth", cover="edges"),
)
try:
edge = next((e for e in order if e.spec.name == name or name in e.virtuals))
except StopIteration as e:
raise KeyError(f"No spec with name {name} in {self}") from e
if self._concrete:
return SpecBuildInterface(
spec, name, query_parameters, _parent=self, is_virtual=is_virtual
edge.spec, name, query_parameters, _parent=self, is_virtual=name in edge.virtuals
)
return spec
return edge.spec
def __contains__(self, spec):
"""True if this spec or some dependency satisfies the spec.

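To recap the lookup order restored above (an editor's sketch, not part of the diff): a query such as spec["lapack"].libs, with "lapack" a virtual provided by the queried package, walks a chain of callbacks and stops at the first one that yields a value. The package and attribute names below are examples.

# callbacks_chain built in __get__ for a 'libs' query through a virtual:
#
#   callbacks_chain = [
#       lambda: getattr(pkg, "lapack_libs"),   # 1. virtual-specialized attribute
#       lambda: getattr(pkg, "libs"),          # 2. generic attribute
#       lambda: _libs_default_handler(spec),   # 3. default handler, searches home
#   ]
#
# The first callback returning a non-None value wins. A None return raises
# AttributeError to report a failed query; if every callback raises
# AttributeError, the descriptor reports that no relevant attribute exists.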
View File

@@ -32,7 +32,7 @@ def repro_dir(tmp_path):
def test_get_added_versions_new_checksum(mock_git_package_changes):
repo, filename, commits = mock_git_package_changes
repo_path, filename, commits = mock_git_package_changes
checksum_versions = {
"3f6576971397b379d4205ae5451ff5a68edf6c103b2f03c4188ed7075fbb5f04": Version("2.1.5"),
@@ -41,7 +41,7 @@ def test_get_added_versions_new_checksum(mock_git_package_changes):
"86993903527d9b12fc543335c19c1d33a93797b3d4d37648b5addae83679ecd8": Version("2.0.0"),
}
with fs.working_dir(repo.packages_path):
with fs.working_dir(str(repo_path)):
added_versions = ci.get_added_versions(
checksum_versions, filename, from_ref=commits[-1], to_ref=commits[-2]
)
@@ -50,7 +50,7 @@ def test_get_added_versions_new_checksum(mock_git_package_changes):
def test_get_added_versions_new_commit(mock_git_package_changes):
repo, filename, commits = mock_git_package_changes
repo_path, filename, commits = mock_git_package_changes
checksum_versions = {
"74253725f884e2424a0dd8ae3f69896d5377f325": Version("2.1.6"),
@@ -60,9 +60,9 @@ def test_get_added_versions_new_commit(mock_git_package_changes):
"86993903527d9b12fc543335c19c1d33a93797b3d4d37648b5addae83679ecd8": Version("2.0.0"),
}
with fs.working_dir(repo.packages_path):
with fs.working_dir(str(repo_path)):
added_versions = ci.get_added_versions(
checksum_versions, filename, from_ref=commits[-2], to_ref=commits[-3]
checksum_versions, filename, from_ref=commits[2], to_ref=commits[1]
)
assert len(added_versions) == 1
assert added_versions[0] == Version("2.1.6")

View File

@@ -1978,13 +1978,6 @@ def test_ci_validate_git_versions_invalid(
assert f"Invalid commit for diff-test@{version}" in err
def mock_packages_path(path):
def packages_path():
return path
return packages_path
@pytest.fixture
def verify_standard_versions_valid(monkeypatch):
def validate_standard_versions(pkg, versions):
@@ -2031,12 +2024,9 @@ def test_ci_verify_versions_valid(
mock_git_package_changes,
verify_standard_versions_valid,
verify_git_versions_valid,
tmpdir,
):
repo, _, commits = mock_git_package_changes
spack.repo.PATH.put_first(repo)
monkeypatch.setattr(spack.repo, "packages_path", mock_packages_path(repo.packages_path))
repo_path, _, commits = mock_git_package_changes
monkeypatch.setattr(spack.paths, "prefix", repo_path)
out = ci_cmd("verify-versions", commits[-1], commits[-3])
assert "Validated diff-test@2.1.5" in out
@@ -2050,10 +2040,9 @@ def test_ci_verify_versions_standard_invalid(
verify_standard_versions_invalid,
verify_git_versions_invalid,
):
repo, _, commits = mock_git_package_changes
spack.repo.PATH.put_first(repo)
repo_path, _, commits = mock_git_package_changes
monkeypatch.setattr(spack.repo, "packages_path", mock_packages_path(repo.packages_path))
monkeypatch.setattr(spack.paths, "prefix", repo_path)
out = ci_cmd("verify-versions", commits[-1], commits[-3], fail_on_error=False)
assert "Invalid checksum found diff-test@2.1.5" in out
@@ -2061,10 +2050,8 @@ def test_ci_verify_versions_standard_invalid(
def test_ci_verify_versions_manual_package(monkeypatch, mock_packages, mock_git_package_changes):
repo, _, commits = mock_git_package_changes
spack.repo.PATH.put_first(repo)
monkeypatch.setattr(spack.repo, "packages_path", mock_packages_path(repo.packages_path))
repo_path, _, commits = mock_git_package_changes
monkeypatch.setattr(spack.paths, "prefix", repo_path)
pkg_class = spack.spec.Spec("diff-test").package_class
monkeypatch.setattr(pkg_class, "manual_download", True)

View File

@@ -243,11 +243,13 @@ def latest_commit():
@pytest.fixture
def mock_git_package_changes(git, tmpdir, override_git_repos_cache_path, monkeypatch):
def mock_git_package_changes(git, tmpdir, override_git_repos_cache_path):
"""Create a mock git repo with known structure of package edits
The structure of commits in this repo is as follows::
o diff-test: modification to make manual download package
|
o diff-test: add v1.2 (from a git ref)
|
o diff-test: add v1.1 (from source tarball)
@@ -259,12 +261,8 @@ def mock_git_package_changes(git, tmpdir, override_git_repos_cache_path, monkeyp
Important attributes of the repo for test coverage are: multiple package
versions are added with some coming from a tarball and some from git refs.
"""
filename = "diff-test/package.py"
repo_path, _ = spack.repo.create_repo(str(tmpdir.mkdir("myrepo")))
repo_cache = spack.util.file_cache.FileCache(str(tmpdir.mkdir("cache")))
repo = spack.repo.Repo(repo_path, cache=repo_cache)
repo_path = str(tmpdir.mkdir("git_package_changes_repo"))
filename = "var/spack/repos/builtin/packages/diff-test/package.py"
def commit(message):
global commit_counter
@@ -278,7 +276,7 @@ def commit(message):
)
commit_counter += 1
with working_dir(repo.packages_path):
with working_dir(repo_path):
git("init")
git("config", "user.name", "Spack")
@@ -309,11 +307,17 @@ def latest_commit():
commit("diff-test: add v2.1.6")
commits.append(latest_commit())
# convert diff-test to a manual download package
shutil.copy2(f"{spack.paths.test_path}/data/conftest/diff-test/package-3.txt", filename)
git("add", filename)
commit("diff-test: modification to make manual download package")
commits.append(latest_commit())
# The commits are ordered with the last commit first in the list
commits = list(reversed(commits))
# Yield the git repo directory, the filename used, and the commits
yield repo, filename, commits
yield repo_path, filename, commits
@pytest.fixture(autouse=True)

View File

@@ -0,0 +1,23 @@
# Copyright Spack Project Developers. See COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack.package import *
class DiffTest(AutotoolsPackage):
"""zlib replacement with optimizations for next generation systems."""
homepage = "https://github.com/zlib-ng/zlib-ng"
url = "https://github.com/zlib-ng/zlib-ng/archive/2.0.0.tar.gz"
git = "https://github.com/zlib-ng/zlib-ng.git"
license("Zlib")
manual_download = True
version("2.1.6", tag="2.1.6", commit="74253725f884e2424a0dd8ae3f69896d5377f325")
version("2.1.5", sha256="3f6576971397b379d4205ae5451ff5a68edf6c103b2f03c4188ed7075fbb5f04")
version("2.1.4", sha256="a0293475e6a44a3f6c045229fe50f69dc0eebc62a42405a51f19d46a5541e77a")
version("2.0.7", sha256="6c0853bb27738b811f2b4d4af095323c3d5ce36ceed6b50e5f773204fb8f7200")
version("2.0.0", sha256="86993903527d9b12fc543335c19c1d33a93797b3d4d37648b5addae83679ecd8")

View File

@@ -7,9 +7,9 @@
def test_modified_files(mock_git_package_changes):
repo, filename, commits = mock_git_package_changes
repo_path, filename, commits = mock_git_package_changes
with working_dir(repo.packages_path):
with working_dir(repo_path):
files = get_modified_files(from_ref="HEAD~1", to_ref="HEAD")
assert len(files) == 1
assert files[0] == filename

View File

@@ -27,8 +27,9 @@ spack:
- py-transformers
# JAX
- py-jax
- py-jaxlib
# Does not yet support Spack-installed ROCm
# - py-jax
# - py-jaxlib
# Keras
- py-keras backend=tensorflow

View File

@@ -18,7 +18,6 @@ class Asio(AutotoolsPackage):
license("BSL-1.0")
# As uneven minor versions of asio are not considered stable, they won't be added anymore
version("1.34.0", sha256="061ed6c8b97527756aed3e34d2cbcbcb6d3c80afd26ed6304f51119e1ef6a1cd")
version("1.32.0", sha256="f1b94b80eeb00bb63a3c8cef5047d4e409df4d8a3fe502305976965827d95672")
version("1.30.2", sha256="755bd7f85a4b269c67ae0ea254907c078d408cce8e1a352ad2ed664d233780e8")
version("1.30.1", sha256="94b121cc2016680f2314ef58eadf169c2d34fff97fba01df325a192d502d3a58")

View File

@@ -15,8 +15,6 @@ class Glab(GoPackage):
license("MIT")
version("1.55.0", sha256="21f58698b92035461e8e8ba9040429f4b5a0f6d528d8333834ef522a973384c8")
version("1.54.0", sha256="99f5dd785041ad26c8463ae8630e98a657aa542a2bb02333d50243dd5cfdf9cb")
version("1.53.0", sha256="2930aa5dd76030cc6edcc33483bb49dd6a328eb531d0685733ca7be7b906e915")
version("1.52.0", sha256="585495e53d3994172fb927218627b7470678bc766320cb52f4b4204238677dde")
version("1.51.0", sha256="6a95d827004fee258aacb49a427875e3b505b063cc578933d965cd56481f5a19")
@@ -36,38 +34,20 @@ class Glab(GoPackage):
version("1.21.1", sha256="8bb35c5cf6b011ff14d1eaa9ab70ec052d296978792984250e9063b006ee4d50")
version("1.20.0", sha256="6beb0186fa50d0dea3b05fcfe6e4bc1f9be0c07aa5fa15b37ca2047b16980412")
with default_args(type="build"):
depends_on("go@1.24.1:", when="@1.54:")
depends_on("go@1.23.4:", when="@1.52:")
depends_on("go@1.23.2:", when="@1.48:")
depends_on("go@1.23.0:", when="@1.46:")
depends_on("go@1.22.5:", when="@1.44:")
depends_on("go@1.22.4:", when="@1.42:")
depends_on("go@1.22.3:", when="@1.41:")
depends_on("go@1.21.0:", when="@1.37:")
depends_on("go@1.19.0:", when="@1.35:")
depends_on("go@1.18.0:", when="@1.23:")
depends_on("go@1.17.0:", when="@1.22:")
depends_on("go@1.13.0:")
depends_on("go@1.13:", type="build")
depends_on("go@1.17:", type="build", when="@1.22:")
depends_on("go@1.18:", type="build", when="@1.23:")
depends_on("go@1.19:", type="build", when="@1.35:")
depends_on("go@1.21:", type="build", when="@1.37:")
depends_on("go@1.22.3:", type="build", when="@1.41:")
depends_on("go@1.22.4:", type="build", when="@1.42:")
depends_on("go@1.22.5:", type="build", when="@1.44:")
depends_on("go@1.23:", type="build", when="@1.46:")
depends_on("go@1.23.2:", type="build", when="@1.48:")
depends_on("go@1.23.4:", type="build", when="@1.52:")
build_directory = "cmd/glab"
# Required to correctly set the version
# https://gitlab.com/gitlab-org/cli/-/blob/v1.55.0/Makefile?ref_type=tags#L44
@property
def build_args(self):
extra_ldflags = [f"-X 'main.version=v{self.version}'"]
args = super().build_args
if "-ldflags" in args:
ldflags_index = args.index("-ldflags") + 1
args[ldflags_index] = args[ldflags_index] + " " + " ".join(extra_ldflags)
else:
args.extend(["-ldflags", " ".join(extra_ldflags)])
return args
@run_after("install")
def install_completions(self):
glab = Executable(self.prefix.bin.glab)

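For reference (an editor's sketch, not part of the diff), the build_args property above folds the version flag into an existing -ldflags entry when the base class already provides one, and appends a new pair otherwise; e.g. for version 1.55.0:

#   super().build_args == [..., "-ldflags", "-s -w"]    (assumed example value)
#   build_args         == [..., "-ldflags", "-s -w -X 'main.version=v1.55.0'"]
#
#   super().build_args == [...]                          (no -ldflags entry)
#   build_args         == [..., "-ldflags", "-X 'main.version=v1.55.0'"]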
View File

@@ -14,14 +14,12 @@ class GtkDoc(AutotoolsPackage):
pdf/man-pages with some extra work."""
homepage = "https://wiki.gnome.org/DocumentationProject/GtkDoc"
url = "https://download.gnome.org/sources/gtk-doc/1.33/gtk-doc-1.33.2.tar.xz"
list_url = "https://download.gnome.org/sources/gtk-doc/"
list_depth = 1
url = "https://gitlab.gnome.org/GNOME/gtk-doc/-/archive/1.33.2/gtk-doc-1.33.2.tar.gz"
license("GPL-2.0-or-later AND GFDL-1.1-or-later")
version("1.33.2", sha256="cc1b709a20eb030a278a1f9842a362e00402b7f834ae1df4c1998a723152bf43")
version("1.32", sha256="de0ef034fb17cb21ab0c635ec730d19746bce52984a6706e7bbec6fb5e0b907c")
version("1.33.2", sha256="2d1b0cbd26edfcb54694b2339106a02a81d630a7dedc357461aeb186874cc7c0")
version("1.32", sha256="0890c1f00d4817279be51602e67c4805daf264092adc58f9c04338566e8225ba")
depends_on("c", type="build") # generated
@@ -62,8 +60,14 @@ def installcheck(self):
pass
def url_for_version(self, version):
url = "https://download.gnome.org/sources/gtk-doc/{0}/gtk-doc-{1}.tar.xz"
return url.format(version.up_to(2), version)
"""Handle gnome's version-based custom URLs."""
if version <= Version("1.32"):
url = "https://gitlab.gnome.org/GNOME/gtk-doc/-/archive/GTK_DOC_{0}/gtk-doc-GTK_DOC_{0}.tar.gz"
return url.format(version.underscored)
url = "https://gitlab.gnome.org/GNOME/gtk-doc/-/archive/{0}/gtk-doc-{0}.tar.gz"
return url.format(version)
def configure_args(self):
args = ["--with-xml-catalog={0}".format(self["docbook-xml"].catalog)]

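Illustrating the new url_for_version above (an editor's sketch, not part of the diff): releases up to 1.32 use the GTK_DOC_x_y tag layout on GitLab, while newer releases use plain version tags. This assumes Version.underscored renders 1.32 as 1_32.

# url_for_version(Version("1.32"))
#   -> "https://gitlab.gnome.org/GNOME/gtk-doc/-/archive/GTK_DOC_1_32/gtk-doc-GTK_DOC_1_32.tar.gz"
# url_for_version(Version("1.33.2"))
#   -> "https://gitlab.gnome.org/GNOME/gtk-doc/-/archive/1.33.2/gtk-doc-1.33.2.tar.gz"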
View File

@@ -169,10 +169,6 @@ class Hpx(CMakePackage, CudaPackage, ROCmPackage):
# Patches and one-off conflicts
# Asio 1.34.0 removed io_context::work, used by HPX:
# https://github.com/chriskohlhoff/asio/commit/a70f2df321ff40c1809773c2c09986745abf8d20.
conflicts("^asio@1.34:", when="@:1.10")
# Certain Asio headers don't compile with nvcc from 1.17.0 onwards with
# C++17. Starting with CUDA 11.3 they compile again.
conflicts("^asio@1.17.0:", when="+cuda cxxstd=17 ^cuda@:11.2")

View File

@@ -20,13 +20,13 @@ class PyJax(PythonPackage):
maintainers("adamjstewart", "jonas-eschle")
# version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
# version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
# version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
# version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
# version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
# version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
# version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
# version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
@@ -85,13 +85,13 @@ class PyJax(PythonPackage):
# https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
for v in [
# "0.5.0",
"0.4.38",
"0.4.37",
"0.4.36",
"0.4.35",
"0.4.34",
"0.4.33",
"0.4.32",
# "0.4.38",
# "0.4.37",
# "0.4.36",
# "0.4.35",
# "0.4.34",
# "0.4.33",
# "0.4.32",
"0.4.31",
"0.4.30",
"0.4.29",
@@ -126,12 +126,12 @@ class PyJax(PythonPackage):
# See _minimum_jaxlib_version in jax/version.py
# depends_on("py-jaxlib@0.5:", when="@0.5:")
depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
# depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
# depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
# depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
# depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
# depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
# depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")

View File

@@ -8,27 +8,20 @@
from spack.package import *
rocm_dependencies = [
"comgr",
"hip",
"hipblas",
"hipblaslt",
"hipcub",
"hipfft",
"hiprand",
"hipsolver",
"hipsparse",
"hsa-rocr-dev",
"miopen-hip",
"hip",
"rccl",
"rocblas",
"rocfft",
"rocminfo",
"rocprim",
"rocrand",
"rocsolver",
"rocsparse",
"hipcub",
"rocthrust",
"roctracer-dev",
"rocm-core",
"rocrand",
"hipsparse",
"hipfft",
"rocfft",
"rocblas",
"miopen-hip",
"rocminfo",
]
@@ -46,17 +39,14 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
license("Apache-2.0")
maintainers("adamjstewart", "jonas-eschle")
# version("0.5.3", sha256="1094581a30ec069965f4e3e67d60262570cc3dd016adc62073bc24347b14270c")
# version("0.5.2", sha256="8e9de1e012dd65fc4a9eec8af4aa2bf6782767130a5d8e1c1e342b7d658280fe")
# version("0.5.1", sha256="e74b1209517682075933f757d646b73040d09fe39ee3e9e4cd398407dd0902d2")
# version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
# version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
# version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
# version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
# version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
# version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
# version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
# version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
@@ -103,10 +93,6 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
for pkg_dep in rocm_dependencies:
depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
depends_on(pkg_dep)
depends_on("rocprofiler-register", when="^hip@6.2:")
depends_on("hipblas-common", when="^hip@6.3:")
depends_on("hsakmt-roct", when="^hip@:6.2")
depends_on("llvm-amdgpu")
depends_on("py-nanobind")
with default_args(type="build"):
@@ -127,7 +113,6 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
depends_on("python@3.9:", when="@0.4.14:")
depends_on("python@3.8:", when="@0.4.6:")
depends_on("python@:3.13")
depends_on("python@:3.12", when="+rocm")
depends_on("python@:3.12", when="@:0.4.33")
depends_on("python@:3.11", when="@:0.4.16")
@@ -182,22 +167,6 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
# Fails to build with freshly released CUDA (#48708).
conflicts("^cuda@12.8:", when="@:0.4.31")
# external CUDA is not supported https://github.com/jax-ml/jax/issues/23689
conflicts("+cuda", when="@0.4.32:")
# aarch64 is not supported https://github.com/jax-ml/jax/issues/25598
conflicts("target=aarch64:", when="@0.4.32:")
resource(
name="xla",
url="https://github.com/ROCm/xla/archive/07543ab117699a57c1267b453a62f89b1d5953fd.tar.gz",
sha256="cee377479654201c61cc3f230d89603cd589525fea2faf44564a23c70ba1448d",
expand=True,
destination="",
placement="xla",
when="@0.4.38:0.5.2 +rocm",
)
def url_for_version(self, version):
url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz"
if version >= Version("0.4.33"):
@@ -206,20 +175,6 @@ def url_for_version(self, version):
name = "jaxlib"
return url.format(name, version)
def setup_build_environment(self, env):
spec = self.spec
if spec.satisfies("@0.4.38: +rocm") and not spec["hip"].external:
if spec.satisfies("^hip@6.2:"):
rocm_dependencies.append("rocprofiler-register")
if spec.satisfies("^hip@6.3:"):
rocm_dependencies.append("hipblas-common")
else:
rocm_dependencies.append("hsakmt-roct")
env.set("LLVM_PATH", spec["llvm-amdgpu"].prefix)
for pkg_dep in rocm_dependencies:
env.prepend_path("TF_ROCM_MULTIPLE_PATHS", spec[pkg_dep].prefix)
env.prune_duplicate_paths("TF_ROCM_MULTIPLE_PATHS")
def install(self, spec, prefix):
# https://jax.readthedocs.io/en/latest/developer.html
args = ["build/build.py"]
@@ -261,15 +216,7 @@ def install(self, spec, prefix):
args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}")
if "+rocm" in spec:
args.append(f"--rocm_path={self.spec['hip'].prefix}")
if spec.satisfies("@:0.4.35"):
args.append("--enable_rocm")
if spec.satisfies("@0.4.38:") and not spec["hip"].external:
args.append("--bazel_options=--@local_config_rocm//rocm:rocm_path_type=multiple")
if spec.satisfies("@0.4.38:0.5.2"):
args.append(
f"--bazel_options=--override_repository=xla={self.stage.source_path}/xla"
)
args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"])
args.extend(
[
@@ -280,6 +227,5 @@ def install(self, spec, prefix):
)
python(*args)
for whl in glob.glob(join_path("dist", "*.whl")):
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)
whl = glob.glob(join_path("dist", "*.whl"))[0]
pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)