diff --git a/lib/spack/spack/cmd/__init__.py b/lib/spack/spack/cmd/__init__.py index 030a95c2388..8a2351b412f 100644 --- a/lib/spack/spack/cmd/__init__.py +++ b/lib/spack/spack/cmd/__init__.py @@ -9,7 +9,7 @@ import re import sys from collections import Counter -from typing import List, Optional, Union +from typing import Generator, List, Optional, Sequence, Union import llnl.string import llnl.util.tty as tty @@ -704,6 +704,57 @@ def first_line(docstring): return docstring.split("\n")[0] +def group_arguments( + args: Sequence[str], max_group_size: int = 500, max_group_len: Optional[int] = None +) -> Generator[List[str], None, None]: + """Splits the supplied list of arguments into groups for passing to CLI tools. + + When passing CLI arguments, we need to ensure that argument lists are no longer than + the system command line size limit, and we may also need to ensure that groups are + no more than some number of arguments long. + + This returns an iterator over lists of arguments that meet these constraints. + Arguments are in the same order they appeared in the original argument list. + + If any argument's length is greater than the max_group_length, this will raise a + ``ValueError``. + + Arguments: + args: list of arguments to split into groups + max_group_size: max number of elements in any group (default 500) + max_group_len: max length of characters that if a group of args is joined by " " + On unix, ths defaults to SC_ARG_MAX from sysconf. On Windows the default is + the max usable for CreateProcess (32,768 chars) + + """ + if max_group_len is None: + max_group_len = 32768 # default to the Windows limit + if hasattr(os, "sysconf"): + # sysconf is only on unix and returns -1 if an option isn't present + sysconf_max = os.sysconf("SC_ARG_MAX") + if sysconf_max != -1: + max_group_len = sysconf_max + + group: List[str] = [] + grouplen, space = 0, 0 + for i, arg in enumerate(args): + arglen = len(arg) + if arglen > max_group_len: + raise ValueError(f"Argument is longer than the maximum command line size: '{arg}'") + + next_grouplen = grouplen + arglen + space + if len(group) == max_group_size or next_grouplen > max_group_len: + yield group + group, grouplen, space = [], 0, 0 + + group.append(arg) + grouplen += arglen + space + space = 1 # add a space for elements 1, 2, etc. but not 0 + + if group: + yield group + + class CommandNotFoundError(spack.error.SpackError): """Exception class thrown when a requested command is not recognized as such. diff --git a/lib/spack/spack/cmd/pkg.py b/lib/spack/spack/cmd/pkg.py index e4dffc744fa..70387563fd0 100644 --- a/lib/spack/spack/cmd/pkg.py +++ b/lib/spack/spack/cmd/pkg.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: (Apache-2.0 OR MIT) import argparse -import itertools import os import sys @@ -182,21 +181,19 @@ def pkg_grep(args, unknown_args): if "GNU" in grep("--version", output=str): grep.add_default_arg("--color=auto") - # determines number of files to grep at a time - grouper = lambda e: e[0] // 100 + all_paths = spack.repo.PATH.all_package_paths() + if not all_paths: + return 0 # no packages to search # set up iterator and save the first group to ensure we don't end up with a group of size 1 - groups = itertools.groupby(enumerate(spack.repo.PATH.all_package_paths()), grouper) - if not groups: - return 0 # no packages to search + groups = spack.cmd.group_arguments(all_paths) # You can force GNU grep to show filenames on every line with -H, but not POSIX grep. # POSIX grep only shows filenames when you're grepping 2 or more files. Since we # don't know which one we're running, we ensure there are always >= 2 files by # saving the prior group of paths and adding it to a straggling group of 1 if needed. # This works unless somehow there is only one package in all of Spack. - _, first_group = next(groups) - prior_paths = [path for _, path in first_group] + prior_paths = next(groups) # grep returns 1 for nothing found, 0 for something found, and > 1 for error return_code = 1 @@ -207,9 +204,7 @@ def grep_group(paths): grep(*all_args, fail_on_error=False) return grep.returncode - for _, group in groups: - paths = [path for _, path in group] # extract current path group - + for paths in groups: if len(paths) == 1: # Only the very last group can have length 1. If it does, combine # it with the prior group to ensure more than one path is grepped. diff --git a/lib/spack/spack/test/cmd/pkg.py b/lib/spack/spack/test/cmd/pkg.py index 0fe2153872b..53d587603bb 100644 --- a/lib/spack/spack/test/cmd/pkg.py +++ b/lib/spack/spack/test/cmd/pkg.py @@ -307,10 +307,56 @@ def test_pkg_hash(mock_packages): assert len(output) == 1 and all(len(elt) == 32 for elt in output) +group_args = [ + "/path/one.py", # 12 + "/path/two.py", # 12 + "/path/three.py", # 14 + "/path/four.py", # 13 + "/path/five.py", # 13 + "/path/six.py", # 12 + "/path/seven.py", # 14 + "/path/eight.py", # 14 + "/path/nine.py", # 13 + "/path/ten.py", # 12 +] + + +@pytest.mark.parametrize( + ["max_group_size", "max_group_len", "lengths", "error"], + [ + (3, 1, None, ValueError), + (3, 13, None, ValueError), + (3, 25, [2, 1, 1, 1, 1, 1, 1, 1, 1], None), + (3, 26, [2, 1, 1, 2, 1, 1, 2], None), + (3, 40, [3, 3, 2, 2], None), + (3, 43, [3, 3, 3, 1], None), + (4, 54, [4, 3, 3], None), + (4, 56, [4, 4, 2], None), + ], +) +def test_group_arguments(mock_packages, max_group_size, max_group_len, lengths, error): + generator = spack.cmd.group_arguments( + group_args, max_group_size=max_group_size, max_group_len=max_group_len + ) + + # just check that error cases raise + if error: + with pytest.raises(ValueError): + list(generator) + return + + groups = list(generator) + assert sum(groups, []) == group_args + assert [len(group) for group in groups] == lengths + assert all( + sum(len(elt) for elt in group) + (len(group) - 1) <= max_group_len for group in groups + ) + + @pytest.mark.skipif(not spack.cmd.pkg.get_grep(), reason="grep is not installed") def test_pkg_grep(mock_packages, capfd): # only splice-* mock packages have the string "splice" in them - pkg("grep", "-l", "splice", output=str) + pkg("grep", "-l", "splice") output, _ = capfd.readouterr() assert output.strip() == "\n".join( spack.repo.PATH.get_pkg_class(name).module.__file__ @@ -330,12 +376,14 @@ def test_pkg_grep(mock_packages, capfd): ] ) - # ensure that this string isn't fouhnd - pkg("grep", "abcdefghijklmnopqrstuvwxyz", output=str, fail_on_error=False) + # ensure that this string isn't found + with pytest.raises(spack.main.SpackCommandError): + pkg("grep", "abcdefghijklmnopqrstuvwxyz") assert pkg.returncode == 1 output, _ = capfd.readouterr() assert output.strip() == "" # ensure that we return > 1 for an error - pkg("grep", "--foobarbaz-not-an-option", output=str, fail_on_error=False) + with pytest.raises(spack.main.SpackCommandError): + pkg("grep", "--foobarbaz-not-an-option") assert pkg.returncode == 2