Use process pool executors for web-crawling and retrieving archives (#39888)

Fix a race condition when searching URLs and updating the shared set '_visited'.
Massimiliano Culpo
2023-09-19 15:32:59 +02:00
committed by GitHub
parent 7d33c36a30
commit 3b4ca0374e
20 changed files with 300 additions and 288 deletions
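
At its core, the change swaps a thread pool that mutated a closed-over `_visited` set for a `concurrent.futures.ProcessPoolExecutor` whose workers receive the visited set as an argument and return their updates, which the parent merges serially. A minimal standalone sketch of that pattern (the `crawl_one`/`crawl` names and fake child links are invented for illustration, not Spack code):

```python
# Minimal sketch of the pattern: workers never write to shared state; the parent
# merges each worker's returned view of the visited set after the future resolves.
import concurrent.futures
from typing import Iterable, Set, Tuple


def crawl_one(url: str, visited: Set[str]) -> Tuple[Set[str], Set[str]]:
    # Pretend to fetch ``url`` and discover two child links not seen before.
    discovered = {f"{url}/child-1", f"{url}/child-2"} - visited
    return discovered, visited | discovered | {url}


def crawl(roots: Iterable[str], max_workers: int = 2) -> Set[str]:
    visited: Set[str] = set()
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        # Each worker gets a pickled copy of ``visited``; no locks are required.
        futures = [executor.submit(crawl_one, url, set(visited)) for url in roots]
        for future in futures:
            _, worker_visited = future.result()
            visited.update(worker_visited)  # merge in the parent, serially
    return visited


if __name__ == "__main__":
    print(sorted(crawl(["https://example.com"])))
```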

View File

@@ -34,6 +34,7 @@
import spack.cmd
import spack.config as config
import spack.database as spack_db
import spack.error
import spack.hooks
import spack.hooks.sbang
import spack.mirror
@@ -1417,7 +1418,7 @@ def try_fetch(url_to_fetch):
try:
stage.fetch()
except web_util.FetchError:
except spack.error.FetchError:
stage.destroy()
return None
@@ -2144,7 +2145,7 @@ def get_keys(install=False, trust=False, force=False, mirrors=None):
if not os.path.exists(stage.save_filename):
try:
stage.fetch()
except web_util.FetchError:
except spack.error.FetchError:
continue
tty.debug("Found key {0}".format(fingerprint))
@@ -2296,7 +2297,7 @@ def _download_buildcache_entry(mirror_root, descriptions):
try:
stage.fetch()
break
except web_util.FetchError as e:
except spack.error.FetchError as e:
tty.debug(e)
else:
if fail_if_missing:

View File

@@ -527,7 +527,7 @@ def copy_buildcache_file(src_url, dest_url, local_path=None):
temp_stage.create()
temp_stage.fetch()
web_util.push_to_url(local_path, dest_url, keep_original=True)
except web_util.FetchError as e:
except spack.error.FetchError as e:
# Expected, since we have to try all the possible extensions
tty.debug("no such file: {0}".format(src_url))
tty.debug(e)

View File

@@ -66,7 +66,7 @@ def setup_parser(subparser):
modes_parser.add_argument(
"--verify", action="store_true", default=False, help="verify known package checksums"
)
arguments.add_common_arguments(subparser, ["package"])
arguments.add_common_arguments(subparser, ["package", "jobs"])
subparser.add_argument(
"versions", nargs=argparse.REMAINDER, help="versions to generate checksums for"
)
@@ -96,7 +96,7 @@ def checksum(parser, args):
# Add latest version if requested
if args.latest:
remote_versions = pkg.fetch_remote_versions()
remote_versions = pkg.fetch_remote_versions(args.jobs)
if len(remote_versions) > 0:
latest_version = sorted(remote_versions.keys(), reverse=True)[0]
versions.append(latest_version)
@@ -119,13 +119,13 @@ def checksum(parser, args):
# if we get here, it's because no valid url was provided by the package
# do expensive fallback to try to recover
if remote_versions is None:
remote_versions = pkg.fetch_remote_versions()
remote_versions = pkg.fetch_remote_versions(args.jobs)
if version in remote_versions:
url_dict[version] = remote_versions[version]
if len(versions) <= 0:
if remote_versions is None:
remote_versions = pkg.fetch_remote_versions()
remote_versions = pkg.fetch_remote_versions(args.jobs)
url_dict = remote_versions
if not url_dict:

View File

@@ -37,10 +37,7 @@ def setup_parser(subparser):
action="store_true",
help="only list remote versions newer than the latest checksummed version",
)
subparser.add_argument(
"-c", "--concurrency", default=32, type=int, help="number of concurrent requests"
)
arguments.add_common_arguments(subparser, ["package"])
arguments.add_common_arguments(subparser, ["package", "jobs"])
def versions(parser, args):
@@ -68,7 +65,7 @@ def versions(parser, args):
if args.safe:
return
fetched_versions = pkg.fetch_remote_versions(args.concurrency)
fetched_versions = pkg.fetch_remote_versions(args.jobs)
if args.new:
if sys.stdout.isatty():

View File

@@ -128,3 +128,7 @@ def __init__(self, provided, required, constraint_type):
self.provided = provided
self.required = required
self.constraint_type = constraint_type
class FetchError(SpackError):
"""Superclass for fetch-related errors."""

View File

@@ -401,7 +401,7 @@ def _fetch_curl(self, url):
try:
web_util.check_curl_code(curl.returncode)
except web_util.FetchError as err:
except spack.error.FetchError as err:
raise spack.fetch_strategy.FailedDownloadError(url, str(err))
self._check_headers(headers)
@@ -1290,7 +1290,7 @@ def fetch(self):
parsed_url = urllib.parse.urlparse(self.url)
if parsed_url.scheme != "s3":
raise web_util.FetchError("S3FetchStrategy can only fetch from s3:// urls.")
raise spack.error.FetchError("S3FetchStrategy can only fetch from s3:// urls.")
tty.debug("Fetching {0}".format(self.url))
@@ -1337,7 +1337,7 @@ def fetch(self):
parsed_url = urllib.parse.urlparse(self.url)
if parsed_url.scheme != "gs":
raise web_util.FetchError("GCSFetchStrategy can only fetch from gs:// urls.")
raise spack.error.FetchError("GCSFetchStrategy can only fetch from gs:// urls.")
tty.debug("Fetching {0}".format(self.url))
@@ -1431,7 +1431,7 @@ def from_kwargs(**kwargs):
on attribute names (e.g., ``git``, ``hg``, etc.)
Raises:
spack.util.web.FetchError: If no ``fetch_strategy`` matches the args.
spack.error.FetchError: If no ``fetch_strategy`` matches the args.
"""
for fetcher in all_strategies:
if fetcher.matches(kwargs):
@@ -1538,7 +1538,7 @@ def for_package_version(pkg, version=None):
# if it's a commit, we must use a GitFetchStrategy
if isinstance(version, spack.version.GitVersion):
if not hasattr(pkg, "git"):
raise web_util.FetchError(
raise spack.error.FetchError(
f"Cannot fetch git version for {pkg.name}. Package has no 'git' attribute"
)
# Populate the version with comparisons to other commits
@@ -1688,11 +1688,11 @@ def destroy(self):
shutil.rmtree(self.root, ignore_errors=True)
class NoCacheError(web_util.FetchError):
class NoCacheError(spack.error.FetchError):
"""Raised when there is no cached archive for a package."""
class FailedDownloadError(web_util.FetchError):
class FailedDownloadError(spack.error.FetchError):
"""Raised when a download fails."""
def __init__(self, url, msg=""):
@@ -1700,23 +1700,23 @@ def __init__(self, url, msg=""):
self.url = url
class NoArchiveFileError(web_util.FetchError):
class NoArchiveFileError(spack.error.FetchError):
"""Raised when an archive file is expected but none exists."""
class NoDigestError(web_util.FetchError):
class NoDigestError(spack.error.FetchError):
"""Raised after attempt to checksum when URL has no digest."""
class ExtrapolationError(web_util.FetchError):
class ExtrapolationError(spack.error.FetchError):
"""Raised when we can't extrapolate a version for a package."""
class FetcherConflict(web_util.FetchError):
class FetcherConflict(spack.error.FetchError):
"""Raised for packages with invalid fetch attributes."""
class InvalidArgsError(web_util.FetchError):
class InvalidArgsError(spack.error.FetchError):
"""Raised when a version can't be deduced from a set of arguments."""
def __init__(self, pkg=None, version=None, **args):
@@ -1729,11 +1729,11 @@ def __init__(self, pkg=None, version=None, **args):
super().__init__(msg, long_msg)
class ChecksumError(web_util.FetchError):
class ChecksumError(spack.error.FetchError):
"""Raised when archive fails to checksum."""
class NoStageError(web_util.FetchError):
class NoStageError(spack.error.FetchError):
"""Raised when fetch operations are called before set_stage()."""
def __init__(self, method):

View File

@@ -66,7 +66,6 @@
from spack.stage import DIYStage, ResourceStage, Stage, StageComposite, compute_stage_name
from spack.util.executable import ProcessError, which
from spack.util.package_hash import package_hash
from spack.util.web import FetchError
from spack.version import GitVersion, StandardVersion, Version
FLAG_HANDLER_RETURN_TYPE = Tuple[
@@ -1394,7 +1393,7 @@ def do_fetch(self, mirror_only=False):
tty.debug("Fetching with no checksum. {0}".format(ck_msg))
if not ignore_checksum:
raise FetchError(
raise spack.error.FetchError(
"Will not fetch %s" % self.spec.format("{name}{@version}"), ck_msg
)
@@ -1420,7 +1419,7 @@ def do_fetch(self, mirror_only=False):
tty.debug("Fetching deprecated version. {0}".format(dp_msg))
if not ignore_deprecation:
raise FetchError(
raise spack.error.FetchError(
"Will not fetch {0}".format(self.spec.format("{name}{@version}")), dp_msg
)
@@ -1447,7 +1446,7 @@ def do_stage(self, mirror_only=False):
self.stage.expand_archive()
if not os.listdir(self.stage.path):
raise FetchError("Archive was empty for %s" % self.name)
raise spack.error.FetchError("Archive was empty for %s" % self.name)
else:
# Support for post-install hooks requires a stage.source_path
fsys.mkdirp(self.stage.source_path)
@@ -2365,7 +2364,7 @@ def all_urls(self):
urls.append(args["url"])
return urls
def fetch_remote_versions(self, concurrency=128):
def fetch_remote_versions(self, concurrency=None):
"""Find remote versions of this package.
Uses ``list_url`` and any other URLs listed in the package file.

View File

@@ -76,7 +76,7 @@ def __init__(self, pkg, path_or_url, level, working_dir):
self.level = level
self.working_dir = working_dir
def apply(self, stage: spack.stage.Stage):
def apply(self, stage: "spack.stage.Stage"):
"""Apply a patch to source in a stage.
Arguments:
@@ -190,7 +190,7 @@ def __init__(self, pkg, url, level=1, working_dir=".", ordering_key=None, **kwar
if not self.sha256:
raise PatchDirectiveError("URL patches require a sha256 checksum")
def apply(self, stage: spack.stage.Stage):
def apply(self, stage: "spack.stage.Stage"):
assert self.stage.expanded, "Stage must be expanded before applying patches"
# Get the patch file.

View File

@@ -2,7 +2,7 @@
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import concurrent.futures
import errno
import getpass
import glob
@@ -12,7 +12,7 @@
import stat
import sys
import tempfile
from typing import Dict, Iterable
from typing import Callable, Dict, Iterable, Optional
import llnl.util.lang
import llnl.util.tty as tty
@@ -37,9 +37,9 @@
import spack.util.lock
import spack.util.path as sup
import spack.util.pattern as pattern
import spack.util.string
import spack.util.url as url_util
from spack.util.crypto import bit_length, prefix_bits
from spack.util.web import FetchError
# The well-known stage source subdirectory name.
_source_path_subdir = "spack-src"
@@ -241,10 +241,7 @@ class Stage:
similar, and are intended to persist for only one run of spack.
"""
"""Shared dict of all stage locks."""
stage_locks: Dict[str, spack.util.lock.Lock] = {}
"""Most staging is managed by Spack. DIYStage is one exception."""
#: Most staging is managed by Spack. DIYStage is one exception.
managed_by_spack = True
def __init__(
@@ -330,17 +327,12 @@ def __init__(
# details on this approach.
self._lock = None
if lock:
if self.name not in Stage.stage_locks:
sha1 = hashlib.sha1(self.name.encode("utf-8")).digest()
lock_id = prefix_bits(sha1, bit_length(sys.maxsize))
stage_lock_path = os.path.join(get_stage_root(), ".lock")
tty.debug("Creating stage lock {0}".format(self.name))
Stage.stage_locks[self.name] = spack.util.lock.Lock(
stage_lock_path, start=lock_id, length=1, desc=self.name
)
self._lock = Stage.stage_locks[self.name]
sha1 = hashlib.sha1(self.name.encode("utf-8")).digest()
lock_id = prefix_bits(sha1, bit_length(sys.maxsize))
stage_lock_path = os.path.join(get_stage_root(), ".lock")
self._lock = spack.util.lock.Lock(
stage_lock_path, start=lock_id, length=1, desc=self.name
)
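
As an aside, each stage still locks a distinct one-byte region of the shared `.lock` file; the start offset is derived from a SHA-1 of the stage name. A hedged sketch of that derivation in plain Python (assuming `prefix_bits` returns the leading `bit_length(sys.maxsize)` bits of the digest as an integer):

```python
# Rough standard-library equivalent of the lock_id computation above.
import hashlib
import sys

name = "example-stage"
digest = hashlib.sha1(name.encode("utf-8")).digest()
nbits = sys.maxsize.bit_length()  # 63 on 64-bit CPython
lock_id = int.from_bytes(digest, "big") >> (len(digest) * 8 - nbits)
print(lock_id)  # start offset of the 1-byte region this stage locks in .lock
```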
# When stages are reused, we need to know whether to re-create
# it. This marks whether it has been created/destroyed.
@@ -522,7 +514,7 @@ def print_errors(errors):
self.fetcher = self.default_fetcher
default_msg = "All fetchers failed for {0}".format(self.name)
raise FetchError(err_msg or default_msg, None)
raise spack.error.FetchError(err_msg or default_msg, None)
print_errors(errors)
@@ -868,45 +860,47 @@ def purge():
os.remove(stage_path)
def get_checksums_for_versions(url_dict, name, **kwargs):
"""Fetches and checksums archives from URLs.
def get_checksums_for_versions(
url_by_version: Dict[str, str],
package_name: str,
*,
batch: bool = False,
first_stage_function: Optional[Callable[[Stage, str], None]] = None,
keep_stage: bool = False,
concurrency: Optional[int] = None,
fetch_options: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
"""Computes the checksums for each version passed in input, and returns the results.
This function is called by both ``spack checksum`` and ``spack
create``. The ``first_stage_function`` argument allows the caller to
inspect the first downloaded archive, e.g., to determine the build
system.
Archives are fetched according to the URL dictionary passed as input.
The ``first_stage_function`` argument allows the caller to inspect the first downloaded
archive, e.g., to determine the build system.
Args:
url_dict (dict): A dictionary of the form: version -> URL
name (str): The name of the package
first_stage_function (typing.Callable): function that takes a Stage and a URL;
this is run on the stage of the first URL downloaded
keep_stage (bool): whether to keep staging area when command completes
batch (bool): whether to ask user how many versions to fetch (false)
or fetch all versions (true)
fetch_options (dict): Options used for the fetcher (such as timeout
or cookies)
url_by_version: URL keyed by version
package_name: name of the package
first_stage_function: function that takes a Stage and a URL; this is run on the stage
of the first URL downloaded
keep_stage: whether to keep staging area when command completes
batch: whether to ask user how many versions to fetch (false) or fetch all versions (true)
fetch_options: options used for the fetcher (such as timeout or cookies)
concurrency: maximum number of workers to use for retrieving archives
Returns:
(dict): A dictionary of the form: version -> checksum
A dictionary mapping each version to the corresponding checksum
"""
batch = kwargs.get("batch", False)
fetch_options = kwargs.get("fetch_options", None)
first_stage_function = kwargs.get("first_stage_function", None)
keep_stage = kwargs.get("keep_stage", False)
sorted_versions = sorted(url_dict.keys(), reverse=True)
sorted_versions = sorted(url_by_version.keys(), reverse=True)
# Find length of longest string in the list for padding
max_len = max(len(str(v)) for v in sorted_versions)
num_ver = len(sorted_versions)
tty.msg(
"Found {0} version{1} of {2}:".format(num_ver, "" if num_ver == 1 else "s", name),
f"Found {spack.util.string.plural(num_ver, 'version')} of {package_name}:",
"",
*llnl.util.lang.elide_list(
["{0:{1}} {2}".format(str(v), max_len, url_dict[v]) for v in sorted_versions]
["{0:{1}} {2}".format(str(v), max_len, url_by_version[v]) for v in sorted_versions]
),
)
print()
@@ -922,50 +916,76 @@ def get_checksums_for_versions(url_dict, name, **kwargs):
tty.die("Aborted.")
versions = sorted_versions[:archives_to_fetch]
urls = [url_dict[v] for v in versions]
search_arguments = [(url_by_version[v], v) for v in versions]
tty.debug("Downloading...")
version_hashes = {}
i = 0
errors = []
for url, version in zip(urls, versions):
try:
if fetch_options:
url_or_fs = fs.URLFetchStrategy(url, fetch_options=fetch_options)
else:
url_or_fs = url
with Stage(url_or_fs, keep=keep_stage) as stage:
# Fetch the archive
stage.fetch()
if i == 0 and first_stage_function:
# Only run first_stage_function the first time,
# no need to run it every time
first_stage_function(stage, url)
version_hashes, errors = {}, []
# Checksum the archive and add it to the list
version_hashes[version] = spack.util.crypto.checksum(
hashlib.sha256, stage.archive_file
)
i += 1
except FailedDownloadError:
errors.append("Failed to fetch {0}".format(url))
except Exception as e:
tty.msg("Something failed on {0}, skipping. ({1})".format(url, e))
# Don't spawn 16 processes when we need to fetch 2 urls
if concurrency is not None:
concurrency = min(concurrency, len(search_arguments))
else:
concurrency = min(os.cpu_count() or 1, len(search_arguments))
for msg in errors:
tty.debug(msg)
# The function might have in-memory side effects that would not be reflected in the
# parent process if it ran in a child process. If this pattern happens frequently, we
# can move this function call *after* having distributed the work to executors.
if first_stage_function is not None:
(url, version), search_arguments = search_arguments[0], search_arguments[1:]
checksum, error = _fetch_and_checksum(url, fetch_options, keep_stage, first_stage_function)
if error is not None:
errors.append(error)
if checksum is not None:
version_hashes[version] = checksum
with concurrent.futures.ProcessPoolExecutor(max_workers=concurrency) as executor:
results = []
for url, version in search_arguments:
future = executor.submit(_fetch_and_checksum, url, fetch_options, keep_stage)
results.append((version, future))
for version, future in results:
checksum, error = future.result()
if error is not None:
errors.append(error)
continue
version_hashes[version] = checksum
for msg in errors:
tty.debug(msg)
if not version_hashes:
tty.die("Could not fetch any versions for {0}".format(name))
tty.die(f"Could not fetch any versions for {package_name}")
num_hash = len(version_hashes)
tty.debug(
"Checksummed {0} version{1} of {2}:".format(num_hash, "" if num_hash == 1 else "s", name)
)
tty.debug(f"Checksummed {num_hash} version{'' if num_hash == 1 else 's'} of {package_name}:")
return version_hashes
def _fetch_and_checksum(url, options, keep_stage, action_fn=None):
try:
url_or_fs = url
if options:
url_or_fs = fs.URLFetchStrategy(url, fetch_options=options)
with Stage(url_or_fs, keep=keep_stage) as stage:
# Fetch the archive
stage.fetch()
if action_fn is not None:
# Only run first_stage_function the first time,
# no need to run it every time
action_fn(stage, url)
# Checksum the archive and add it to the list
checksum = spack.util.crypto.checksum(hashlib.sha256, stage.archive_file)
return checksum, None
except FailedDownloadError:
return None, f"[WORKER] Failed to fetch {url}"
except Exception as e:
return None, f"[WORKER] Something failed on {url}, skipping. ({e})"
class StageError(spack.error.SpackError):
""" "Superclass for all errors encountered during staging."""

View File

@@ -14,7 +14,6 @@
import spack.spec
import spack.store
from spack.main import SpackCommand, SpackCommandError
from spack.util.web import FetchError
pytestmark = pytest.mark.usefixtures("config", "mutable_mock_repo")
@@ -208,7 +207,7 @@ def test_env_aware_spec(mutable_mock_env_path):
[
("develop-branch-version", "f3c7206350ac8ee364af687deaae5c574dcfca2c=develop", None),
("develop-branch-version", "git." + "a" * 40 + "=develop", None),
("callpath", "f3c7206350ac8ee364af687deaae5c574dcfca2c=1.0", FetchError),
("callpath", "f3c7206350ac8ee364af687deaae5c574dcfca2c=1.0", spack.error.FetchError),
("develop-branch-version", "git.foo=0.2.15", None),
],
)

View File

@@ -36,6 +36,7 @@
import spack.database
import spack.directory_layout
import spack.environment as ev
import spack.error
import spack.package_base
import spack.package_prefs
import spack.paths
@@ -52,7 +53,6 @@
import spack.util.url as url_util
from spack.fetch_strategy import URLFetchStrategy
from spack.util.pattern import Bunch
from spack.util.web import FetchError
def ensure_configuration_fixture_run_before(request):
@@ -472,7 +472,7 @@ def fetcher(self, target_path, digest, **kwargs):
class MockCacheFetcher:
def fetch(self):
raise FetchError("Mock cache always fails for tests")
raise spack.error.FetchError("Mock cache always fails for tests")
def __str__(self):
return "[mock fetch cache]"

View File

@@ -8,9 +8,9 @@
import pytest
import spack.config
import spack.error
import spack.fetch_strategy
import spack.stage
from spack.util.web import FetchError
@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
@@ -33,7 +33,7 @@ def test_gcsfetchstrategy_bad_url(tmpdir, _fetch_method):
with spack.stage.Stage(fetcher, path=testpath) as stage:
assert stage is not None
assert fetcher.archive_file is None
with pytest.raises(FetchError):
with pytest.raises(spack.error.FetchError):
fetcher.fetch()

View File

@@ -20,6 +20,7 @@
import spack.binary_distribution as bindist
import spack.cmd.buildcache as buildcache
import spack.error
import spack.package_base
import spack.repo
import spack.store
@@ -522,7 +523,7 @@ def _instr(pkg):
monkeypatch.setattr(spack.package_base.PackageBase, "download_instr", _instr)
expected = spec.package.download_instr if manual else "All fetchers failed"
with pytest.raises(spack.util.web.FetchError, match=expected):
with pytest.raises(spack.error.FetchError, match=expected):
spec.package.do_fetch()

View File

@@ -8,9 +8,9 @@
import pytest
import spack.config as spack_config
import spack.error
import spack.fetch_strategy as spack_fs
import spack.stage as spack_stage
from spack.util.web import FetchError
@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
@@ -33,7 +33,7 @@ def test_s3fetchstrategy_bad_url(tmpdir, _fetch_method):
with spack_stage.Stage(fetcher, path=testpath) as stage:
assert stage is not None
assert fetcher.archive_file is None
with pytest.raises(FetchError):
with pytest.raises(spack.error.FetchError):
fetcher.fetch()

View File

@@ -16,6 +16,7 @@
from llnl.util.filesystem import getuid, mkdirp, partition_path, touch, working_dir
import spack.error
import spack.paths
import spack.stage
import spack.util.executable
@@ -23,7 +24,6 @@
from spack.resource import Resource
from spack.stage import DIYStage, ResourceStage, Stage, StageComposite
from spack.util.path import canonicalize_path
from spack.util.web import FetchError
# The following values are used for common fetch and stage mocking fixtures:
_archive_base = "test-files"
@@ -522,7 +522,7 @@ def test_no_search_mirror_only(self, failing_fetch_strategy, failing_search_fn):
with stage:
try:
stage.fetch(mirror_only=True)
except FetchError:
except spack.error.FetchError:
pass
check_destroy(stage, self.stage_name)
@@ -537,7 +537,7 @@ def test_search_if_default_fails(self, failing_fetch_strategy, search_fn, err_ms
stage = Stage(failing_fetch_strategy, name=self.stage_name, search_fn=search_fn)
with stage:
with pytest.raises(FetchError, match=expected):
with pytest.raises(spack.error.FetchError, match=expected):
stage.fetch(mirror_only=False, err_msg=err_msg)
check_destroy(stage, self.stage_name)

View File

@@ -13,6 +13,7 @@
from llnl.util.filesystem import is_exe, working_dir
import spack.config
import spack.error
import spack.fetch_strategy as fs
import spack.repo
import spack.util.crypto as crypto
@@ -349,7 +350,7 @@ def _which(*args, **kwargs):
def test_url_fetch_text_without_url(tmpdir):
with pytest.raises(web_util.FetchError, match="URL is required"):
with pytest.raises(spack.error.FetchError, match="URL is required"):
web_util.fetch_url_text(None)
@@ -366,18 +367,18 @@ def _which(*args, **kwargs):
monkeypatch.setattr(spack.util.web, "which", _which)
with spack.config.override("config:url_fetch_method", "curl"):
with pytest.raises(web_util.FetchError, match="Missing required curl"):
with pytest.raises(spack.error.FetchError, match="Missing required curl"):
web_util.fetch_url_text("https://github.com/")
def test_url_check_curl_errors():
"""Check that standard curl error returncodes raise expected errors."""
# Check returncode 22 (i.e., 404)
with pytest.raises(web_util.FetchError, match="not found"):
with pytest.raises(spack.error.FetchError, match="not found"):
web_util.check_curl_code(22)
# Check returncode 60 (certificate error)
with pytest.raises(web_util.FetchError, match="invalid certificate"):
with pytest.raises(spack.error.FetchError, match="invalid certificate"):
web_util.check_curl_code(60)
@@ -394,7 +395,7 @@ def _which(*args, **kwargs):
monkeypatch.setattr(spack.util.web, "which", _which)
with spack.config.override("config:url_fetch_method", "curl"):
with pytest.raises(web_util.FetchError, match="Missing required curl"):
with pytest.raises(spack.error.FetchError, match="Missing required curl"):
web_util.url_exists("https://github.com/")
@@ -409,7 +410,7 @@ def _read_from_url(*args, **kwargs):
monkeypatch.setattr(spack.util.web, "read_from_url", _read_from_url)
with spack.config.override("config:url_fetch_method", "urllib"):
with pytest.raises(web_util.FetchError, match="failed with error code"):
with pytest.raises(spack.error.FetchError, match="failed with error code"):
web_util.fetch_url_text("https://github.com/")
@@ -420,5 +421,5 @@ def _raise_web_error(*args, **kwargs):
monkeypatch.setattr(spack.util.web, "read_from_url", _raise_web_error)
with spack.config.override("config:url_fetch_method", "urllib"):
with pytest.raises(web_util.FetchError, match="fetch failed to verify"):
with pytest.raises(spack.error.FetchError, match="fetch failed to verify"):
web_util.fetch_url_text("https://github.com/")

View File

@@ -98,7 +98,7 @@ def test_spider(depth, expected_found, expected_not_found, expected_text):
def test_spider_no_response(monkeypatch):
# Mock the absence of a response
monkeypatch.setattr(spack.util.web, "read_from_url", lambda x, y: (None, None, None))
pages, links = spack.util.web.spider(root, depth=0)
pages, links, _, _ = spack.util.web._spider(root, collect_nested=False, _visited=set())
assert not pages and not links

View File

@@ -4,9 +4,9 @@
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import codecs
import concurrent.futures
import email.message
import errno
import multiprocessing.pool
import os
import os.path
import re
@@ -17,7 +17,7 @@
import urllib.parse
from html.parser import HTMLParser
from pathlib import Path, PurePosixPath
from typing import IO, Optional
from typing import IO, Dict, List, Optional, Set, Union
from urllib.error import HTTPError, URLError
from urllib.request import HTTPSHandler, Request, build_opener
@@ -257,11 +257,11 @@ def check_curl_code(returncode):
if returncode != 0:
if returncode == 22:
# This is a 404. Curl will print the error.
raise FetchError("URL was not found!")
raise spack.error.FetchError("URL was not found!")
if returncode == 60:
# This is a certificate error. Suggest spack -k
raise FetchError(
raise spack.error.FetchError(
"Curl was unable to fetch due to invalid certificate. "
"This is either an attack, or your cluster's SSL "
"configuration is bad. If you believe your SSL "
@@ -270,7 +270,7 @@ def check_curl_code(returncode):
"Use this at your own risk."
)
raise FetchError("Curl failed with error {0}".format(returncode))
raise spack.error.FetchError("Curl failed with error {0}".format(returncode))
def _curl(curl=None):
@@ -279,7 +279,7 @@ def _curl(curl=None):
curl = which("curl", required=True)
except CommandNotFoundError as exc:
tty.error(str(exc))
raise FetchError("Missing required curl fetch method")
raise spack.error.FetchError("Missing required curl fetch method")
return curl
@@ -307,7 +307,7 @@ def fetch_url_text(url, curl=None, dest_dir="."):
Raises FetchError if the curl returncode indicates failure
"""
if not url:
raise FetchError("A URL is required to fetch its text")
raise spack.error.FetchError("A URL is required to fetch its text")
tty.debug("Fetching text at {0}".format(url))
@@ -319,7 +319,7 @@ def fetch_url_text(url, curl=None, dest_dir="."):
if fetch_method == "curl":
curl_exe = _curl(curl)
if not curl_exe:
raise FetchError("Missing required fetch method (curl)")
raise spack.error.FetchError("Missing required fetch method (curl)")
curl_args = ["-O"]
curl_args.extend(base_curl_fetch_args(url))
@@ -337,7 +337,9 @@ def fetch_url_text(url, curl=None, dest_dir="."):
returncode = response.getcode()
if returncode and returncode != 200:
raise FetchError("Urllib failed with error code {0}".format(returncode))
raise spack.error.FetchError(
"Urllib failed with error code {0}".format(returncode)
)
output = codecs.getreader("utf-8")(response).read()
if output:
@@ -348,7 +350,7 @@ def fetch_url_text(url, curl=None, dest_dir="."):
return path
except SpackWebError as err:
raise FetchError("Urllib fetch failed to verify url: {0}".format(str(err)))
raise spack.error.FetchError("Urllib fetch failed to verify url: {0}".format(str(err)))
return None
@@ -543,170 +545,160 @@ def list_url(url, recursive=False):
return gcs.get_all_blobs(recursive=recursive)
def spider(root_urls, depth=0, concurrency=32):
def spider(root_urls: Union[str, List[str]], depth: int = 0, concurrency: Optional[int] = None):
"""Get web pages from root URLs.
If depth is specified (e.g., depth=2), then this will also follow
up to <depth> levels of links from each root.
If depth is specified (e.g., depth=2), then this will also follow up to <depth> levels
of links from each root.
Args:
root_urls (str or list): root urls used as a starting point
for spidering
depth (int): level of recursion into links
concurrency (int): number of simultaneous requests that can be sent
root_urls: root urls used as a starting point for spidering
depth: level of recursion into links
concurrency: number of simultaneous requests that can be sent
Returns:
A dict of pages visited (URL) mapped to their full text and the
set of visited links.
A dict of pages visited (URL) mapped to their full text and the set of visited links.
"""
# Cache of visited links, meant to be captured by the closure below
_visited = set()
def _spider(url, collect_nested):
"""Fetches URL and any pages it links to.
Prints out a warning only if the root can't be fetched; it ignores
errors with pages that the root links to.
Args:
url (str): url being fetched and searched for links
collect_nested (bool): whether we want to collect arguments
for nested spidering on the links found in this url
Returns:
A tuple of:
- pages: dict of pages visited (URL) mapped to their full text.
- links: set of links encountered while visiting the pages.
- spider_args: argument for subsequent call to spider
"""
pages = {} # dict from page URL -> text content.
links = set() # set of all links seen on visited pages.
subcalls = []
try:
response_url, _, response = read_from_url(url, "text/html")
if not response_url or not response:
return pages, links, subcalls
page = codecs.getreader("utf-8")(response).read()
pages[response_url] = page
# Parse out the include-fragments in the page
# https://github.github.io/include-fragment-element
include_fragment_parser = IncludeFragmentParser()
include_fragment_parser.feed(page)
fragments = set()
while include_fragment_parser.links:
raw_link = include_fragment_parser.links.pop()
abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True)
try:
# This seems to be text/html, though text/fragment+html is also used
fragment_response_url, _, fragment_response = read_from_url(
abs_link, "text/html"
)
except Exception as e:
msg = f"Error reading fragment: {(type(e), str(e))}:{traceback.format_exc()}"
tty.debug(msg)
if not fragment_response_url or not fragment_response:
continue
fragment = codecs.getreader("utf-8")(fragment_response).read()
fragments.add(fragment)
pages[fragment_response_url] = fragment
# Parse out the links in the page and all fragments
link_parser = LinkParser()
link_parser.feed(page)
for fragment in fragments:
link_parser.feed(fragment)
while link_parser.links:
raw_link = link_parser.links.pop()
abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True)
links.add(abs_link)
# Skip stuff that looks like an archive
if any(raw_link.endswith(s) for s in llnl.url.ALLOWED_ARCHIVE_TYPES):
continue
# Skip already-visited links
if abs_link in _visited:
continue
# If we're not at max depth, follow links.
if collect_nested:
subcalls.append((abs_link,))
_visited.add(abs_link)
except URLError as e:
tty.debug(str(e))
if hasattr(e, "reason") and isinstance(e.reason, ssl.SSLError):
tty.warn(
"Spack was unable to fetch url list due to a "
"certificate verification problem. You can try "
"running spack -k, which will not check SSL "
"certificates. Use this at your own risk."
)
except HTMLParseError as e:
# This error indicates that Python's HTML parser sucks.
msg = "Got an error parsing HTML."
tty.warn(msg, url, "HTMLParseError: " + str(e))
except Exception as e:
# Other types of errors are completely ignored,
# except in debug mode
tty.debug("Error in _spider: %s:%s" % (type(e), str(e)), traceback.format_exc())
finally:
tty.debug("SPIDER: [url={0}]".format(url))
return pages, links, subcalls
if isinstance(root_urls, str):
root_urls = [root_urls]
# Clear the local cache of visited pages before starting the search
_visited.clear()
current_depth = 0
pages, links, spider_args = {}, set(), []
collect = current_depth < depth
for root in root_urls:
root = urllib.parse.urlparse(root)
spider_args.append((root, collect))
_visited: Set[str] = set()
go_deeper = current_depth < depth
for root_str in root_urls:
root = urllib.parse.urlparse(root_str)
spider_args.append((root, go_deeper, _visited))
tp = multiprocessing.pool.ThreadPool(processes=concurrency)
try:
with concurrent.futures.ProcessPoolExecutor(max_workers=concurrency) as tp:
while current_depth <= depth:
tty.debug(
"SPIDER: [depth={0}, max_depth={1}, urls={2}]".format(
current_depth, depth, len(spider_args)
)
f"SPIDER: [depth={current_depth}, max_depth={depth}, urls={len(spider_args)}]"
)
results = tp.map(lang.star(_spider), spider_args)
results = [tp.submit(_spider, *one_search_args) for one_search_args in spider_args]
spider_args = []
collect = current_depth < depth
for sub_pages, sub_links, sub_spider_args in results:
sub_spider_args = [x + (collect,) for x in sub_spider_args]
go_deeper = current_depth < depth
for future in results:
sub_pages, sub_links, sub_spider_args, sub_visited = future.result()
_visited.update(sub_visited)
sub_spider_args = [(x, go_deeper, _visited) for x in sub_spider_args]
pages.update(sub_pages)
links.update(sub_links)
spider_args.extend(sub_spider_args)
current_depth += 1
finally:
tp.terminate()
tp.join()
return pages, links
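
A hedged usage sketch of the reworked `spider()` (the URL is a placeholder; omitting `concurrency` falls back to the executor's default pool size):

```python
# Hypothetical usage: spider() still returns (pages, links); visited-set merging
# now happens in the parent process inside the loop above.
import spack.util.web as web_util

pages, links = web_util.spider("https://example.com/downloads/", depth=1, concurrency=8)
for page_url, text in pages.items():
    print(page_url, len(text))
```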
def _spider(url: urllib.parse.ParseResult, collect_nested: bool, _visited: Set[str]):
"""Fetches URL and any pages it links to.
Prints out a warning only if the root can't be fetched; it ignores errors with pages
that the root links to.
Args:
url: url being fetched and searched for links
collect_nested: whether we want to collect arguments for nested spidering on the
links found in this url
_visited: links already visited
Returns:
A tuple of:
- pages: dict of pages visited (URL) mapped to their full text.
- links: set of links encountered while visiting the pages.
- spider_args: argument for subsequent call to spider
- visited: updated set of visited urls
"""
pages: Dict[str, str] = {} # dict from page URL -> text content.
links: Set[str] = set() # set of all links seen on visited pages.
subcalls: List[str] = []
try:
response_url, _, response = read_from_url(url, "text/html")
if not response_url or not response:
return pages, links, subcalls, _visited
page = codecs.getreader("utf-8")(response).read()
pages[response_url] = page
# Parse out the include-fragments in the page
# https://github.github.io/include-fragment-element
include_fragment_parser = IncludeFragmentParser()
include_fragment_parser.feed(page)
fragments = set()
while include_fragment_parser.links:
raw_link = include_fragment_parser.links.pop()
abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True)
try:
# This seems to be text/html, though text/fragment+html is also used
fragment_response_url, _, fragment_response = read_from_url(abs_link, "text/html")
except Exception as e:
msg = f"Error reading fragment: {(type(e), str(e))}:{traceback.format_exc()}"
tty.debug(msg)
if not fragment_response_url or not fragment_response:
continue
fragment = codecs.getreader("utf-8")(fragment_response).read()
fragments.add(fragment)
pages[fragment_response_url] = fragment
# Parse out the links in the page and all fragments
link_parser = LinkParser()
link_parser.feed(page)
for fragment in fragments:
link_parser.feed(fragment)
while link_parser.links:
raw_link = link_parser.links.pop()
abs_link = url_util.join(response_url, raw_link.strip(), resolve_href=True)
links.add(abs_link)
# Skip stuff that looks like an archive
if any(raw_link.endswith(s) for s in llnl.url.ALLOWED_ARCHIVE_TYPES):
continue
# Skip already-visited links
if abs_link in _visited:
continue
# If we're not at max depth, follow links.
if collect_nested:
subcalls.append(abs_link)
_visited.add(abs_link)
except URLError as e:
tty.debug(f"[SPIDER] Unable to read: {url}")
tty.debug(str(e), level=2)
if hasattr(e, "reason") and isinstance(e.reason, ssl.SSLError):
tty.warn(
"Spack was unable to fetch url list due to a "
"certificate verification problem. You can try "
"running spack -k, which will not check SSL "
"certificates. Use this at your own risk."
)
except HTMLParseError as e:
# This error indicates that Python's HTML parser sucks.
msg = "Got an error parsing HTML."
tty.warn(msg, url, "HTMLParseError: " + str(e))
except Exception as e:
# Other types of errors are completely ignored,
# except in debug mode
tty.debug(f"Error in _spider: {type(e)}:{str(e)}", traceback.format_exc())
finally:
tty.debug(f"SPIDER: [url={url}]")
return pages, links, subcalls, _visited
def get_header(headers, header_name):
"""Looks up a dict of headers for the given header value.
@@ -767,10 +759,6 @@ def parse_etag(header_value):
return valid.group(1) if valid else None
class FetchError(spack.error.SpackError):
"""Superclass for fetch-related errors."""
class SpackWebError(spack.error.SpackError):
"""Superclass for Spack web spidering errors."""