fetch_strategy.py: show progress (#50003)

Show progress meter for fetches when `stdout` is a `tty`.

* fetch_strategy.py: show progress
* "Fetched: x MB at y MB/s"
* add tests, show % if content-length
This commit is contained in:
Harmen Stoppels 2025-04-11 21:39:42 +02:00 committed by GitHub
parent 8fc1ccc686
commit cc3d40d9d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 256 additions and 3 deletions

View File

@ -27,11 +27,14 @@
import os import os
import re import re
import shutil import shutil
import sys
import time
import urllib.error import urllib.error
import urllib.parse import urllib.parse
import urllib.request import urllib.request
import urllib.response
from pathlib import PurePath from pathlib import PurePath
from typing import List, Optional from typing import Callable, List, Mapping, Optional
import llnl.url import llnl.url
import llnl.util import llnl.util
@ -219,6 +222,114 @@ def mirror_id(self):
"""BundlePackages don't have a mirror id.""" """BundlePackages don't have a mirror id."""
def _format_speed(total_bytes: int, elapsed: float) -> str:
"""Return a human-readable average download speed string."""
elapsed = 1 if elapsed <= 0 else elapsed # avoid divide by zero
speed = total_bytes / elapsed
if speed >= 1e9:
return f"{speed / 1e9:6.1f} GB/s"
elif speed >= 1e6:
return f"{speed / 1e6:6.1f} MB/s"
elif speed >= 1e3:
return f"{speed / 1e3:6.1f} KB/s"
return f"{speed:6.1f} B/s"
def _format_bytes(total_bytes: int) -> str:
"""Return a human-readable total bytes string."""
if total_bytes >= 1e9:
return f"{total_bytes / 1e9:7.2f} GB"
elif total_bytes >= 1e6:
return f"{total_bytes / 1e6:7.2f} MB"
elif total_bytes >= 1e3:
return f"{total_bytes / 1e3:7.2f} KB"
return f"{total_bytes:7.2f} B"
class FetchProgress:
    """Rate-limited, single-line progress display for a download.

    Shows a percentage when the total size is known and a rotating spinner
    otherwise, followed by the bytes received so far and the average speed
    since the instance was created. When ``enabled`` is false every method is
    a no-op and ``get_time`` is never called.
    """

    #: Characters to rotate in the spinner.
    spinner = ["|", "/", "-", "\\"]

    def __init__(
        self,
        total_bytes: Optional[int] = None,
        enabled: bool = True,
        get_time: Callable[[], float] = time.time,
    ) -> None:
        """Initialize a FetchProgress instance.

        Args:
            total_bytes: Total number of bytes to download, if known. ``None``,
                zero and negative values all mean "unknown" and are stored as 0.
            enabled: Whether to print progress information.
            get_time: Function to get the current time.
        """
        #: Number of bytes downloaded so far.
        self.current_bytes = 0
        #: Delta time between progress prints
        self.delta = 0.1
        #: Whether to print progress information.
        self.enabled = enabled
        #: Function to get the current time.
        self.get_time = get_time
        #: Time of last progress print to limit output
        self.last_printed = 0.0
        #: Time of start of download (the clock is only read when enabled)
        self.start_time = get_time() if enabled else 0.0
        #: Total number of bytes to download, if known.
        self.total_bytes = total_bytes if total_bytes and total_bytes > 0 else 0
        #: Index of spinner character to print (used if total bytes is unknown)
        self.index = 0

    @classmethod
    def from_headers(
        cls,
        headers: Mapping[str, str],
        enabled: bool = True,
        get_time: Callable[[], float] = time.time,
    ) -> "FetchProgress":
        """Create a FetchProgress instance from HTTP headers.

        The total size is taken from the ``Content-Length`` header. A missing,
        empty or non-integer value results in an unknown total.
        """
        # headers.get is case-insensitive if it's from a HTTPResponse object.
        content_length = headers.get("Content-Length")
        try:
            total_bytes = int(content_length) if content_length else None
        except ValueError:
            total_bytes = None
        return cls(total_bytes=total_bytes, enabled=enabled, get_time=get_time)

    def advance(self, num_bytes: int, out=None) -> None:
        """Record ``num_bytes`` more received bytes and print an update if one is due.

        Args:
            num_bytes: Number of bytes received since the previous call.
            out: Text stream to write to; ``None`` means the current ``sys.stdout``.
        """
        if not self.enabled:
            return
        self.current_bytes += num_bytes
        self.print(out=out)

    def print(self, final: bool = False, out=None) -> None:
        """Write one progress line, at most once per ``self.delta`` seconds.

        Args:
            final: Force the line to be written regardless of the rate limit
                and terminate it with a newline.
            out: Text stream to write to; ``None`` means the current ``sys.stdout``.
        """
        if not self.enabled:
            return

        current_time = self.get_time()
        due = final or self.last_printed + self.delta < current_time
        if not due:
            return
        self.last_printed = current_time

        # Resolve the stream at call time. A ``sys.stdout`` default in the
        # signature would be bound once at definition and would miss any later
        # replacement of sys.stdout (redirection, output capturing).
        stream = sys.stdout if out is None else out

        # print a newline if this is the final update
        maybe_newline = "\n" if final else ""

        # if we know the total bytes, show a percentage, otherwise a spinner
        if self.total_bytes > 0:
            percentage = min(100 * self.current_bytes / self.total_bytes, 100.0)
            percent_or_spinner = f"[{percentage:3.0f}%] "
        elif final:
            # only show the spinner if we are not at 100%
            percent_or_spinner = "[100%] "
        else:
            percent_or_spinner = f"[ {self.spinner[self.index]} ] "
            self.index = (self.index + 1) % len(self.spinner)

        # The leading "\r" rewrites the same terminal line on every update.
        print(
            f"\r {percent_or_spinner}{_format_bytes(self.current_bytes)} "
            f"@ {_format_speed(self.current_bytes, current_time - self.start_time)}"
            f"{maybe_newline}",
            end="",
            flush=True,
            file=stream,
        )
@fetcher @fetcher
class URLFetchStrategy(FetchStrategy): class URLFetchStrategy(FetchStrategy):
"""URLFetchStrategy pulls source code from a URL for an archive, check the """URLFetchStrategy pulls source code from a URL for an archive, check the
@ -316,7 +427,7 @@ def _check_headers(self, headers):
tty.warn(msg) tty.warn(msg)
@_needs_stage @_needs_stage
def _fetch_urllib(self, url): def _fetch_urllib(self, url, chunk_size=65536):
save_file = self.stage.save_filename save_file = self.stage.save_filename
request = urllib.request.Request(url, headers={"User-Agent": web_util.SPACK_USER_AGENT}) request = urllib.request.Request(url, headers={"User-Agent": web_util.SPACK_USER_AGENT})
@ -327,8 +438,15 @@ def _fetch_urllib(self, url):
try: try:
response = web_util.urlopen(request) response = web_util.urlopen(request)
tty.msg(f"Fetching {url}") tty.msg(f"Fetching {url}")
progress = FetchProgress.from_headers(response.headers, enabled=sys.stdout.isatty())
with open(save_file, "wb") as f: with open(save_file, "wb") as f:
shutil.copyfileobj(response, f) while True:
chunk = response.read(chunk_size)
if not chunk:
break
f.write(chunk)
progress.advance(len(chunk))
progress.print(final=True)
except OSError as e: except OSError as e:
# clean up archive on failure. # clean up archive on failure.
if self.archive_file: if self.archive_file:

View File

@ -2,6 +2,8 @@
# #
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
from io import StringIO
import pytest import pytest
from spack import fetch_strategy from spack import fetch_strategy
@ -13,3 +15,136 @@ def test_fetchstrategy_bad_url_scheme():
with pytest.raises(ValueError): with pytest.raises(ValueError):
fetcher = fetch_strategy.from_url_scheme("bogus-scheme://example.com/a/b/c") # noqa: F841 fetcher = fetch_strategy.from_url_scheme("bogus-scheme://example.com/a/b/c") # noqa: F841
# NOTE(review): the expected strings below look as if runs of spaces were
# collapsed when this page was extracted (`_format_bytes` pads the number to 7
# characters) -- confirm the literals against the repository copy.
@pytest.mark.parametrize(
    "expected,total_bytes",
    [
        (" 0.00 B", 0),
        (" 999.00 B", 999),
        (" 1.00 KB", 1000),
        (" 2.05 KB", 2048),
        (" 1.00 MB", 1e6),
        (" 12.30 MB", 1.23e7),
        (" 1.23 GB", 1.23e9),
        (" 999.99 GB", 9.9999e11),
        ("5000.00 GB", 5e12),
    ],
)
def test_format_bytes(expected, total_bytes):
    """Check `_format_bytes` output for byte counts from 0 B up to the GB range."""
    assert fetch_strategy._format_bytes(total_bytes) == expected
# NOTE(review): the expected strings below look as if runs of spaces were
# collapsed when this page was extracted (`_format_speed` pads the number to 6
# characters) -- confirm the literals against the repository copy.
@pytest.mark.parametrize(
    "expected,total_bytes,elapsed",
    [
        (" 0.0 B/s", 0, 0),  # no time passed -- defaults to 1s.
        (" 0.0 B/s", 0, 1),
        (" 999.0 B/s", 999, 1),
        (" 1.0 KB/s", 1000, 1),
        (" 500.0 B/s", 1000, 2),
        (" 2.0 KB/s", 2048, 1),
        (" 1.0 MB/s", 1e6, 1),
        (" 500.0 KB/s", 1e6, 2),
        (" 12.3 MB/s", 1.23e7, 1),
        (" 1.2 GB/s", 1.23e9, 1),
        (" 999.9 GB/s", 9.999e11, 1),
        ("5000.0 GB/s", 5e12, 1),
    ],
)
def test_format_speed(expected, total_bytes, elapsed):
    """Check `_format_speed` output for several byte-count / elapsed-time pairs."""
    assert fetch_strategy._format_speed(total_bytes, elapsed) == expected
def test_fetch_progress_unknown_size():
    """Without a known total, progress shows a spinner and is rate limited.

    The fake clock hands out one time stamp per call: the first is consumed by
    the constructor (start time), the following ones by each `advance`/`print`.
    """
    # time stamps in seconds, with 0.1s delta except 1.5 -> 1.55.
    time_stamps = iter([1.0, 1.5, 1.55, 2.0, 3.0, 5.0, 5.5, 5.5])
    progress = fetch_strategy.FetchProgress(total_bytes=None, get_time=lambda: next(time_stamps))
    assert progress.start_time == 1.0
    out = StringIO()
    progress.advance(1000, out)
    assert progress.last_printed == 1.5
    progress.advance(50, out)
    assert progress.last_printed == 1.5  # does not print, too early after last print
    progress.advance(2000, out)
    assert progress.last_printed == 2.0
    progress.advance(3000, out)
    assert progress.last_printed == 3.0
    progress.advance(4000, out)
    assert progress.last_printed == 5.0
    progress.advance(4000, out)
    assert progress.last_printed == 5.5
    progress.print(final=True, out=out)  # finalize download
    # NOTE(review): runs of spaces inside these literals may have been collapsed
    # when this page was extracted -- confirm against the repository copy.
    outputs = [
        "\r [ | ] 1.00 KB @ 2.0 KB/s",
        "\r [ / ] 3.05 KB @ 3.0 KB/s",
        "\r [ - ] 6.05 KB @ 3.0 KB/s",
        "\r [ \\ ] 10.05 KB @ 2.5 KB/s",  # have to escape \ here but is aligned in output
        "\r [ | ] 14.05 KB @ 3.1 KB/s",
        "\r [100%] 14.05 KB @ 3.1 KB/s\n",  # final print: no spinner; newline
    ]
    assert out.getvalue() == "".join(outputs)
def test_fetch_progress_known_size():
    """With a known total (6000 bytes), progress shows a percentage per update."""
    # One time stamp per call: constructor first, then each advance/print.
    time_stamps = iter([1.0, 1.5, 3.0, 4.0, 4.0])
    progress = fetch_strategy.FetchProgress(total_bytes=6000, get_time=lambda: next(time_stamps))
    out = StringIO()
    progress.advance(1000, out)  # time 1.5
    progress.advance(2000, out)  # time 3.0
    progress.advance(3000, out)  # time 4.0
    progress.print(final=True, out=out)
    # NOTE(review): runs of spaces inside these literals may have been collapsed
    # when this page was extracted -- confirm against the repository copy.
    outputs = [
        "\r [ 17%] 1.00 KB @ 2.0 KB/s",
        "\r [ 50%] 3.00 KB @ 1.5 KB/s",
        "\r [100%] 6.00 KB @ 2.0 KB/s",
        "\r [100%] 6.00 KB @ 2.0 KB/s\n",  # final print has newline
    ]
    assert out.getvalue() == "".join(outputs)
def test_fetch_progress_disabled():
    """When disabled, FetchProgress shouldn't print anything when advanced"""

    # A clock that raises proves the disabled instance never reads the time,
    # neither in the constructor nor in advance/print.
    def get_time():
        raise RuntimeError("Should not be called")

    progress = fetch_strategy.FetchProgress(enabled=False, get_time=get_time)
    out = StringIO()
    progress.advance(1000, out)
    progress.advance(2000, out)
    progress.print(final=True, out=out)
    assert progress.last_printed == 0
    assert not out.getvalue()
@pytest.mark.parametrize(
    "header,value,total_bytes",
    [
        ("Content-Length", "1234", 1234),
        # zero, negative, unparsable and missing lengths all mean "unknown" (0)
        ("Content-Length", "0", 0),
        ("Content-Length", "-10", 0),
        ("Content-Length", "not a number", 0),
        ("Not-Content-Length", "1234", 0),
    ],
)
def test_fetch_progress_from_headers(header, value, total_bytes):
    """`FetchProgress.from_headers` derives the total size from Content-Length."""
    # Only the first time stamp is consumed here (by the constructor).
    time_stamps = iter([1.0, 1.5, 3.0, 4.0, 4.0])
    progress = fetch_strategy.FetchProgress.from_headers(
        {header: value}, get_time=lambda: next(time_stamps), enabled=True
    )
    assert progress.total_bytes == total_bytes
    assert progress.enabled
    assert progress.start_time == 1.0
def test_fetch_progress_from_headers_disabled():
    """`from_headers` passes `enabled=False` through to the created instance."""
    progress = fetch_strategy.FetchProgress.from_headers(
        {"Content-Length": "1234"}, get_time=lambda: 1.0, enabled=False
    )
    assert not progress.enabled