Simplify URLFetchStrategy (#45741)

Harmen Stoppels 2024-08-19 11:34:13 +02:00 committed by GitHub
parent c65fd7e12d
commit 57769fac7d
10 changed files with 93 additions and 233 deletions

View File

@@ -54,7 +54,7 @@
 import spack.version
 import spack.version.git_ref_lookup
 from spack.util.compression import decompressor_for
-from spack.util.executable import CommandNotFoundError, which
+from spack.util.executable import CommandNotFoundError, Executable, which

 #: List of all fetch strategies, created by FetchStrategy metaclass.
 all_strategies = []
@@ -246,33 +246,28 @@ class URLFetchStrategy(FetchStrategy):
     # these are checksum types. The generic 'checksum' is deprecated for
     # specific hash names, but we need it for backward compatibility
-    optional_attrs = list(crypto.hashes.keys()) + ["checksum"]
+    optional_attrs = [*crypto.hashes.keys(), "checksum"]

-    def __init__(self, url=None, checksum=None, **kwargs):
+    def __init__(self, *, url: str, checksum: Optional[str] = None, **kwargs) -> None:
         super().__init__(**kwargs)

-        # Prefer values in kwargs to the positionals.
-        self.url = kwargs.get("url", url)
+        self.url = url
         self.mirrors = kwargs.get("mirrors", [])

-        # digest can be set as the first argument, or from an explicit
-        # kwarg by the hash name.
-        self.digest = kwargs.get("checksum", checksum)
+        self.digest: Optional[str] = checksum
         for h in self.optional_attrs:
             if h in kwargs:
                 self.digest = kwargs[h]

-        self.expand_archive = kwargs.get("expand", True)
-        self.extra_options = kwargs.get("fetch_options", {})
-        self._curl = None
-        self.extension = kwargs.get("extension", None)
-
-        if not self.url:
-            raise ValueError("URLFetchStrategy requires a url for fetching.")
+        self.expand_archive: bool = kwargs.get("expand", True)
+        self.extra_options: dict = kwargs.get("fetch_options", {})
+        self._curl: Optional[Executable] = None
+        self.extension: Optional[str] = kwargs.get("extension", None)

     @property
-    def curl(self):
+    def curl(self) -> Executable:
         if not self._curl:
             self._curl = web_util.require_curl()
         return self._curl
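
Note on the new contract: with url keyword-only and required, the failure mode moves from the deleted runtime ValueError to a TypeError at the call site. A minimal sketch (the URL and hash values are illustrative):

    from spack.fetch_strategy import URLFetchStrategy

    # OK: keyword construction; hash-name kwargs such as sha256 still reach
    # self.digest through the optional_attrs loop above.
    f = URLFetchStrategy(url="https://example.com/foo-1.0.tar.gz", sha256="abc123")

    # Both of these now fail with TypeError instead of ValueError:
    #   URLFetchStrategy("https://example.com/foo-1.0.tar.gz")  # positional url
    #   URLFetchStrategy()                                      # missing url

This is also why the *_sans_url / *_without_url tests are deleted further down: the ValueError they asserted can no longer be raised.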
@@ -348,8 +343,8 @@ def _fetch_urllib(self, url):
         if os.path.lexists(save_file):
             os.remove(save_file)

-        with open(save_file, "wb") as _open_file:
-            shutil.copyfileobj(response, _open_file)
+        with open(save_file, "wb") as f:
+            shutil.copyfileobj(response, f)

         self._check_headers(str(response.headers))
@@ -468,7 +463,7 @@ def check(self):
         """Check the downloaded archive against a checksum digest.

         No-op if this stage checks code out of a repository."""
         if not self.digest:
-            raise NoDigestError("Attempt to check URLFetchStrategy with no digest.")
+            raise NoDigestError(f"Attempt to check {self.__class__.__name__} with no digest.")

         verify_checksum(self.archive_file, self.digest)
@@ -479,8 +474,8 @@ def reset(self):
         """
         if not self.archive_file:
             raise NoArchiveFileError(
-                "Tried to reset URLFetchStrategy before fetching",
-                "Failed on reset() for URL %s" % self.url,
+                f"Tried to reset {self.__class__.__name__} before fetching",
+                f"Failed on reset() for URL {self.url}",
             )

         # Remove everything but the archive from the stage
@@ -493,14 +488,10 @@ def reset(self):
         self.expand()

     def __repr__(self):
-        url = self.url if self.url else "no url"
-        return "%s<%s>" % (self.__class__.__name__, url)
+        return f"{self.__class__.__name__}<{self.url}>"

     def __str__(self):
-        if self.url:
-            return self.url
-        else:
-            return "[no url]"
+        return self.url


 @fetcher
@@ -513,7 +504,7 @@ def fetch(self):
         # check whether the cache file exists.
         if not os.path.isfile(path):
-            raise NoCacheError("No cache of %s" % path)
+            raise NoCacheError(f"No cache of {path}")

         # remove old symlink if one is there.
         filename = self.stage.save_filename
@@ -523,8 +514,8 @@ def fetch(self):
         # Symlink to local cached archive.
         symlink(path, filename)

-        # Remove link if checksum fails, or subsequent fetchers
-        # will assume they don't need to download.
+        # Remove link if checksum fails, or subsequent fetchers will assume they don't need to
+        # download.
         if self.digest:
             try:
                 self.check()
@@ -533,12 +524,12 @@ def fetch(self):
             raise

         # Notify the user how we fetched.
-        tty.msg("Using cached archive: {0}".format(path))
+        tty.msg(f"Using cached archive: {path}")


 class OCIRegistryFetchStrategy(URLFetchStrategy):
-    def __init__(self, url=None, checksum=None, **kwargs):
-        super().__init__(url, checksum, **kwargs)
+    def __init__(self, *, url: str, checksum: Optional[str] = None, **kwargs):
+        super().__init__(url=url, checksum=checksum, **kwargs)

         self._urlopen = kwargs.get("_urlopen", spack.oci.opener.urlopen)
@@ -1381,7 +1372,7 @@ def reset(self):
             shutil.move(scrubbed, source_path)

     def __str__(self):
-        return "[hg] %s" % self.url
+        return f"[hg] {self.url}"


 @fetcher
@@ -1390,47 +1381,16 @@ class S3FetchStrategy(URLFetchStrategy):

     url_attr = "s3"

-    def __init__(self, *args, **kwargs):
-        try:
-            super().__init__(*args, **kwargs)
-        except ValueError:
-            if not kwargs.get("url"):
-                raise ValueError("S3FetchStrategy requires a url for fetching.")
-
     @_needs_stage
     def fetch(self):
+        if not self.url.startswith("s3://"):
+            raise spack.error.FetchError(
+                f"{self.__class__.__name__} can only fetch from s3:// urls."
+            )
         if self.archive_file:
             tty.debug(f"Already downloaded {self.archive_file}")
             return

-        parsed_url = urllib.parse.urlparse(self.url)
-        if parsed_url.scheme != "s3":
-            raise spack.error.FetchError("S3FetchStrategy can only fetch from s3:// urls.")
-
-        basename = os.path.basename(parsed_url.path)
-        request = urllib.request.Request(
-            self.url, headers={"User-Agent": web_util.SPACK_USER_AGENT}
-        )
-
-        with working_dir(self.stage.path):
-            try:
-                response = web_util.urlopen(request)
-            except (TimeoutError, urllib.error.URLError) as e:
-                raise FailedDownloadError(e) from e
-
-            tty.debug(f"Fetching {self.url}")
-
-            with open(basename, "wb") as f:
-                shutil.copyfileobj(response, f)
-
-            content_type = web_util.get_header(response.headers, "Content-type")
-
-        if content_type == "text/html":
-            warn_content_type_mismatch(self.archive_file or "the archive")
-
-        if self.stage.save_filename:
-            fs.rename(os.path.join(self.stage.path, basename), self.stage.save_filename)
-
+        self._fetch_urllib(self.url)
         if not self.archive_file:
             raise FailedDownloadError(
                 RuntimeError(f"Missing archive {self.archive_file} after fetching")
@@ -1443,46 +1403,17 @@ class GCSFetchStrategy(URLFetchStrategy):

     url_attr = "gs"

-    def __init__(self, *args, **kwargs):
-        try:
-            super().__init__(*args, **kwargs)
-        except ValueError:
-            if not kwargs.get("url"):
-                raise ValueError("GCSFetchStrategy requires a url for fetching.")
-
     @_needs_stage
     def fetch(self):
+        if not self.url.startswith("gs"):
+            raise spack.error.FetchError(
+                f"{self.__class__.__name__} can only fetch from gs:// urls."
+            )
         if self.archive_file:
-            tty.debug("Already downloaded {0}".format(self.archive_file))
+            tty.debug(f"Already downloaded {self.archive_file}")
             return

-        parsed_url = urllib.parse.urlparse(self.url)
-        if parsed_url.scheme != "gs":
-            raise spack.error.FetchError("GCSFetchStrategy can only fetch from gs:// urls.")
-
-        basename = os.path.basename(parsed_url.path)
-        request = urllib.request.Request(
-            self.url, headers={"User-Agent": web_util.SPACK_USER_AGENT}
-        )
-
-        with working_dir(self.stage.path):
-            try:
-                response = web_util.urlopen(request)
-            except (TimeoutError, urllib.error.URLError) as e:
-                raise FailedDownloadError(e) from e
-
-            tty.debug(f"Fetching {self.url}")
-
-            with open(basename, "wb") as f:
-                shutil.copyfileobj(response, f)
-
-            content_type = web_util.get_header(response.headers, "Content-type")
-
-        if content_type == "text/html":
-            warn_content_type_mismatch(self.archive_file or "the archive")
-
-        if self.stage.save_filename:
-            os.rename(os.path.join(self.stage.path, basename), self.stage.save_filename)
-
+        self._fetch_urllib(self.url)
         if not self.archive_file:
             raise FailedDownloadError(
@@ -1496,7 +1427,7 @@ class FetchAndVerifyExpandedFile(URLFetchStrategy):
     as well as after expanding it."""

     def __init__(self, url, archive_sha256: str, expanded_sha256: str):
-        super().__init__(url, archive_sha256)
+        super().__init__(url=url, checksum=archive_sha256)
         self.expanded_sha256 = expanded_sha256

     def expand(self):
@@ -1538,14 +1469,14 @@ def stable_target(fetcher):
     return False


-def from_url(url):
+def from_url(url: str) -> URLFetchStrategy:
     """Given a URL, find an appropriate fetch strategy for it.

     Currently just gives you a URLFetchStrategy that uses curl.

     TODO: make this return appropriate fetch strategies for other
     types of URLs.
     """
-    return URLFetchStrategy(url)
+    return URLFetchStrategy(url=url)


 def from_kwargs(**kwargs):
@@ -1614,10 +1545,12 @@ def _check_version_attributes(fetcher, pkg, version):

 def _extrapolate(pkg, version):
     """Create a fetcher from an extrapolated URL for this version."""
     try:
-        return URLFetchStrategy(pkg.url_for_version(version), fetch_options=pkg.fetch_options)
+        return URLFetchStrategy(url=pkg.url_for_version(version), fetch_options=pkg.fetch_options)
     except spack.package_base.NoURLError:
-        msg = "Can't extrapolate a URL for version %s " "because package %s defines no URLs"
-        raise ExtrapolationError(msg % (version, pkg.name))
+        raise ExtrapolationError(
+            f"Can't extrapolate a URL for version {version} because "
+            f"package {pkg.name} defines no URLs"
+        )


 def _from_merged_attrs(fetcher, pkg, version):
@@ -1733,11 +1666,9 @@ def for_package_version(pkg, version=None):
     raise InvalidArgsError(pkg, version, **args)


-def from_url_scheme(url, *args, **kwargs):
+def from_url_scheme(url: str, **kwargs):
     """Finds a suitable FetchStrategy by matching its url_attr with the scheme
     in the given url."""
-    url = kwargs.get("url", url)
     parsed_url = urllib.parse.urlparse(url, scheme="file")

     scheme_mapping = kwargs.get("scheme_mapping") or {
@@ -1754,11 +1685,9 @@ def from_url_scheme(url, *args, **kwargs):
     for fetcher in all_strategies:
         url_attr = getattr(fetcher, "url_attr", None)
         if url_attr and url_attr == scheme:
-            return fetcher(url, *args, **kwargs)
+            return fetcher(url=url, **kwargs)

-    raise ValueError(
-        'No FetchStrategy found for url with scheme: "{SCHEME}"'.format(SCHEME=parsed_url.scheme)
-    )
+    raise ValueError(f'No FetchStrategy found for url with scheme: "{parsed_url.scheme}"')


 def from_list_url(pkg):
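
Usage sketch of the simplified dispatch (values illustrative): the parsed scheme is matched against each registered strategy's url_attr, and all remaining keyword arguments are forwarded unchanged:

    import spack.fetch_strategy as fs

    # "s3" matches S3FetchStrategy.url_attr, so this is equivalent to
    # S3FetchStrategy(url="s3://bucket/pkg-1.0.tar.gz", checksum="abc123")
    fetcher = fs.from_url_scheme("s3://bucket/pkg-1.0.tar.gz", checksum="abc123")

Dropping *args means a positional digest can no longer slip through this path; the caller in Stage._generate_fetchers below now passes checksum=digest explicitly.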
@@ -1783,7 +1712,9 @@ def from_list_url(pkg):
             )

             # construct a fetcher
-            return URLFetchStrategy(url_from_list, checksum, fetch_options=pkg.fetch_options)
+            return URLFetchStrategy(
+                url=url_from_list, checksum=checksum, fetch_options=pkg.fetch_options
+            )
         except KeyError as e:
             tty.debug(e)
             tty.msg("Cannot find version %s in url_list" % pkg.version)
@@ -1811,10 +1742,10 @@ def store(self, fetcher, relative_dest):
         mkdirp(os.path.dirname(dst))
         fetcher.archive(dst)

-    def fetcher(self, target_path, digest, **kwargs):
+    def fetcher(self, target_path: str, digest: Optional[str], **kwargs) -> CacheURLFetchStrategy:
         path = os.path.join(self.root, target_path)
         url = url_util.path_to_file_url(path)
-        return CacheURLFetchStrategy(url, digest, **kwargs)
+        return CacheURLFetchStrategy(url=url, checksum=digest, **kwargs)

     def destroy(self):
         shutil.rmtree(self.root, ignore_errors=True)

View File

@@ -390,7 +390,7 @@ def make_stage(
 ) -> spack.stage.Stage:
     _urlopen = _urlopen or spack.oci.opener.urlopen
     fetch_strategy = spack.fetch_strategy.OCIRegistryFetchStrategy(
-        url, checksum=digest.digest, _urlopen=_urlopen
+        url=url, checksum=digest.digest, _urlopen=_urlopen
     )
     # Use blobs/<alg>/<encoded> as the cache path, which follows
     # the OCI Image Layout Specification. What's missing though,

View File

@@ -319,7 +319,7 @@ def stage(self) -> "spack.stage.Stage":
                 self.url, archive_sha256=self.archive_sha256, expanded_sha256=self.sha256
             )
         else:
-            fetcher = fs.URLFetchStrategy(self.url, sha256=self.sha256, expand=False)
+            fetcher = fs.URLFetchStrategy(url=self.url, sha256=self.sha256, expand=False)

         # The same package can have multiple patches with the same name but
         # with different contents, therefore apply a subset of the hash.

View File

@@ -501,7 +501,7 @@ def _generate_fetchers(self, mirror_only=False) -> Generator[fs.FetchStrategy, None, None]:
             fetchers[:0] = (
                 fs.from_url_scheme(
                     url_util.join(mirror.fetch_url, rel_path),
-                    digest,
+                    checksum=digest,
                     expand=expand,
                     extension=extension,
                 )
@@ -525,13 +525,13 @@ def _generate_fetchers(self, mirror_only=False) -> Generator[fs.FetchStrategy, None, None]:
         if self.search_fn and not mirror_only:
             yield from self.search_fn()

-    def fetch(self, mirror_only=False, err_msg=None):
+    def fetch(self, mirror_only: bool = False, err_msg: Optional[str] = None) -> None:
         """Retrieves the code or archive

         Args:
-            mirror_only (bool): only fetch from a mirror
-            err_msg (str or None): the error message to display if all fetchers
-                fail or ``None`` for the default fetch failure message
+            mirror_only: only fetch from a mirror
+            err_msg: the error message to display if all fetchers fail or ``None`` for the default
+                fetch failure message
         """
         errors: List[str] = []
         for fetcher in self._generate_fetchers(mirror_only):
@@ -593,16 +593,19 @@ def steal_source(self, dest):
         self.destroy()

     def check(self):
-        """Check the downloaded archive against a checksum digest.
-        No-op if this stage checks code out of a repository."""
+        """Check the downloaded archive against a checksum digest."""
         if self.fetcher is not self.default_fetcher and self.skip_checksum_for_mirror:
+            cache = isinstance(self.fetcher, fs.CacheURLFetchStrategy)
+            if cache:
+                secure_msg = "your download cache is in a secure location"
+            else:
+                secure_msg = "you trust this mirror and have a secure connection"
             tty.warn(
-                "Fetching from mirror without a checksum!",
-                "This package is normally checked out from a version "
-                "control system, but it has been archived on a spack "
-                "mirror. This means we cannot know a checksum for the "
-                "tarball in advance. Be sure that your connection to "
-                "this mirror is secure!",
+                f"Using {'download cache' if cache else 'a mirror'} instead of version control",
+                "The required sources are normally checked out from a version control system, "
+                f"but have been archived {'in download cache' if cache else 'on a mirror'}: "
+                f"{self.fetcher}. Spack lacks a tree hash to verify the integrity of this "
+                f"archive. Make sure {secure_msg}.",
             )
         elif spack.config.get("config:checksum"):
             self.fetcher.check()
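
For reference, the cache branch of the new warning renders roughly as follows (wording assembled from the strings above; the exact tty formatting is an assumption):

    ==> Warning: Using download cache instead of version control
        The required sources are normally checked out from a version control
        system, but have been archived in download cache: <fetcher>. Spack
        lacks a tree hash to verify the integrity of this archive. Make sure
        your download cache is in a secure location.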
@@ -1171,7 +1174,7 @@ def _fetch_and_checksum(url, options, keep_stage, action_fn=None):
     try:
         url_or_fs = url
         if options:
-            url_or_fs = fs.URLFetchStrategy(url, fetch_options=options)
+            url_or_fs = fs.URLFetchStrategy(url=url, fetch_options=options)

         with Stage(url_or_fs, keep=keep_stage) as stage:
             # Fetch the archive

View File

@@ -1003,7 +1003,7 @@ def temporary_store(tmpdir, request):
 def mock_fetch(mock_archive, monkeypatch):
     """Fake the URL for a package so it downloads from a file."""
     monkeypatch.setattr(
-        spack.package_base.PackageBase, "fetcher", URLFetchStrategy(mock_archive.url)
+        spack.package_base.PackageBase, "fetcher", URLFetchStrategy(url=mock_archive.url)
     )

View File

@@ -3,54 +3,21 @@
 #
 # SPDX-License-Identifier: (Apache-2.0 OR MIT)
-import os
-
-import pytest
-
-import spack.config
-import spack.error
 import spack.fetch_strategy
 import spack.stage


-@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
-def test_gcsfetchstrategy_without_url(_fetch_method):
-    """Ensure constructor with no URL fails."""
-    with spack.config.override("config:url_fetch_method", _fetch_method):
-        with pytest.raises(ValueError):
-            spack.fetch_strategy.GCSFetchStrategy(None)
-
-
-@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
-def test_gcsfetchstrategy_bad_url(tmpdir, _fetch_method):
-    """Ensure fetch with bad URL fails as expected."""
-    testpath = str(tmpdir)
-    with spack.config.override("config:url_fetch_method", _fetch_method):
-        fetcher = spack.fetch_strategy.GCSFetchStrategy(url="file:///does-not-exist")
-        assert fetcher is not None
-
-        with spack.stage.Stage(fetcher, path=testpath) as stage:
-            assert stage is not None
-            assert fetcher.archive_file is None
-            with pytest.raises(spack.error.FetchError):
-                fetcher.fetch()
-
-
-@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
-def test_gcsfetchstrategy_downloaded(tmpdir, _fetch_method):
+def test_gcsfetchstrategy_downloaded(tmp_path):
     """Ensure fetch with archive file already downloaded is a noop."""
-    testpath = str(tmpdir)
-    archive = os.path.join(testpath, "gcs.tar.gz")
+    archive = tmp_path / "gcs.tar.gz"

-    with spack.config.override("config:url_fetch_method", _fetch_method):
+    class Archived_GCSFS(spack.fetch_strategy.GCSFetchStrategy):
+        @property
+        def archive_file(self):
+            return str(archive)

-        class Archived_GCSFS(spack.fetch_strategy.GCSFetchStrategy):
-            @property
-            def archive_file(self):
-                return archive
-
-        url = "gcs:///{0}".format(archive)
-        fetcher = Archived_GCSFS(url=url)
-        with spack.stage.Stage(fetcher, path=testpath):
-            fetcher.fetch()
+    fetcher = Archived_GCSFS(url="gs://example/gcs.tar.gz")
+    with spack.stage.Stage(fetcher, path=str(tmp_path)):
+        fetcher.fetch()
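
The rewritten test leans on the early return in fetch(): overriding the archive_file property makes the strategy see the file as already staged, so no network access is attempted. The same stub works for any URLFetchStrategy subclass; a standalone sketch (path and URL invented):

    class AlreadyFetched(spack.fetch_strategy.GCSFetchStrategy):
        @property
        def archive_file(self):
            return "/tmp/already-there.tar.gz"  # pretend the download happened

    fetcher = AlreadyFetched(url="gs://example/already-there.tar.gz")
    with spack.stage.Stage(fetcher, path="/tmp"):
        fetcher.fetch()  # no-op: the scheme guard passes, archive_file is truthy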

View File

@@ -205,7 +205,7 @@ def test_invalid_json_mirror_collection(invalid_json, error_message):
 def test_mirror_archive_paths_no_version(mock_packages, mock_archive):
     spec = Spec("trivial-install-test-package@=nonexistingversion").concretized()
-    fetcher = spack.fetch_strategy.URLFetchStrategy(mock_archive.url)
+    fetcher = spack.fetch_strategy.URLFetchStrategy(url=mock_archive.url)

     spack.mirror.mirror_archive_paths(fetcher, "per-package-ref", spec)

View File

@@ -48,7 +48,7 @@
 def test_buildcache(mock_archive, tmp_path, monkeypatch, mutable_config):
     # Install a test package
     spec = Spec("trivial-install-test-package").concretized()
-    monkeypatch.setattr(spec.package, "fetcher", URLFetchStrategy(mock_archive.url))
+    monkeypatch.setattr(spec.package, "fetcher", URLFetchStrategy(url=mock_archive.url))
     spec.package.do_install()
     pkghash = "/" + str(spec.dag_hash(7))

View File

@@ -3,54 +3,19 @@
 #
 # SPDX-License-Identifier: (Apache-2.0 OR MIT)
-import os
-
-import pytest
-
-import spack.config as spack_config
-import spack.error
 import spack.fetch_strategy as spack_fs
 import spack.stage as spack_stage


-@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
-def test_s3fetchstrategy_sans_url(_fetch_method):
-    """Ensure constructor with no URL fails."""
-    with spack_config.override("config:url_fetch_method", _fetch_method):
-        with pytest.raises(ValueError):
-            spack_fs.S3FetchStrategy(None)
-
-
-@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
-def test_s3fetchstrategy_bad_url(tmpdir, _fetch_method):
-    """Ensure fetch with bad URL fails as expected."""
-    testpath = str(tmpdir)
-    with spack_config.override("config:url_fetch_method", _fetch_method):
-        fetcher = spack_fs.S3FetchStrategy(url="file:///does-not-exist")
-        assert fetcher is not None
-
-        with spack_stage.Stage(fetcher, path=testpath) as stage:
-            assert stage is not None
-            assert fetcher.archive_file is None
-            with pytest.raises(spack.error.FetchError):
-                fetcher.fetch()
-
-
-@pytest.mark.parametrize("_fetch_method", ["curl", "urllib"])
-def test_s3fetchstrategy_downloaded(tmpdir, _fetch_method):
+def test_s3fetchstrategy_downloaded(tmp_path):
     """Ensure fetch with archive file already downloaded is a noop."""
-    testpath = str(tmpdir)
-    archive = os.path.join(testpath, "s3.tar.gz")
+    archive = tmp_path / "s3.tar.gz"

-    with spack_config.override("config:url_fetch_method", _fetch_method):
+    class Archived_S3FS(spack_fs.S3FetchStrategy):
+        @property
+        def archive_file(self):
+            return archive

-        class Archived_S3FS(spack_fs.S3FetchStrategy):
-            @property
-            def archive_file(self):
-                return archive
-
-        url = "s3:///{0}".format(archive)
-        fetcher = Archived_S3FS(url=url)
-        with spack_stage.Stage(fetcher, path=testpath):
-            fetcher.fetch()
+    fetcher = Archived_S3FS(url="s3://example/s3.tar.gz")
+    with spack_stage.Stage(fetcher, path=str(tmp_path)):
+        fetcher.fetch()

View File

@@ -76,12 +76,6 @@ def fn_urls(v):
     return factory


-def test_urlfetchstrategy_sans_url():
-    """Ensure constructor with no URL fails."""
-    with pytest.raises(ValueError):
-        fs.URLFetchStrategy(None)
-
-
 @pytest.mark.parametrize("method", ["curl", "urllib"])
 def test_urlfetchstrategy_bad_url(tmp_path, mutable_config, method):
     """Ensure fetch with bad URL fails as expected."""
@@ -267,7 +261,7 @@ def is_true():
     monkeypatch.setattr(sys.stdout, "isatty", is_true)
     monkeypatch.setattr(tty, "msg_enabled", is_true)
     with spack.config.override("config:url_fetch_method", "curl"):
-        fetcher = fs.URLFetchStrategy(mock_archive.url)
+        fetcher = fs.URLFetchStrategy(url=mock_archive.url)
         with Stage(fetcher, path=testpath) as stage:
             assert fetcher.archive_file is None
             stage.fetch()
@@ -280,7 +274,7 @@ def is_true():
 def test_url_extra_fetch(tmp_path, mutable_config, mock_archive, _fetch_method):
     """Ensure a fetch after downloading is effectively a no-op."""
     mutable_config.set("config:url_fetch_method", _fetch_method)
-    fetcher = fs.URLFetchStrategy(mock_archive.url)
+    fetcher = fs.URLFetchStrategy(url=mock_archive.url)
     with Stage(fetcher, path=str(tmp_path)) as stage:
         assert fetcher.archive_file is None
         stage.fetch()