fetch_strategy.py: show progress (#50003)
Show a progress meter for fetches when `stdout` is a `tty`.
* fetch_strategy.py: show progress
* "Fetched: x MB at y MB/s"
* add tests, show % if content-length
This commit is contained in:
parent
8fc1ccc686
commit
cc3d40d9d3
@ -27,11 +27,14 @@
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
import urllib.response
|
||||
from pathlib import PurePath
|
||||
from typing import List, Optional
|
||||
from typing import Callable, List, Mapping, Optional
|
||||
|
||||
import llnl.url
|
||||
import llnl.util
|
||||
@ -219,6 +222,114 @@ def mirror_id(self):
|
||||
"""BundlePackages don't have a mirror id."""
|
||||
|
||||
|
||||
def _format_speed(total_bytes: int, elapsed: float) -> str:
|
||||
"""Return a human-readable average download speed string."""
|
||||
elapsed = 1 if elapsed <= 0 else elapsed # avoid divide by zero
|
||||
speed = total_bytes / elapsed
|
||||
if speed >= 1e9:
|
||||
return f"{speed / 1e9:6.1f} GB/s"
|
||||
elif speed >= 1e6:
|
||||
return f"{speed / 1e6:6.1f} MB/s"
|
||||
elif speed >= 1e3:
|
||||
return f"{speed / 1e3:6.1f} KB/s"
|
||||
return f"{speed:6.1f} B/s"
|
||||
|
||||
|
||||
def _format_bytes(total_bytes: int) -> str:
|
||||
"""Return a human-readable total bytes string."""
|
||||
if total_bytes >= 1e9:
|
||||
return f"{total_bytes / 1e9:7.2f} GB"
|
||||
elif total_bytes >= 1e6:
|
||||
return f"{total_bytes / 1e6:7.2f} MB"
|
||||
elif total_bytes >= 1e3:
|
||||
return f"{total_bytes / 1e3:7.2f} KB"
|
||||
return f"{total_bytes:7.2f} B"
|
||||
|
||||
|
||||
class FetchProgress:
    """Rate-limited, single-line progress meter for a download.

    Callers report received bytes through :meth:`advance` and finish with
    ``print(final=True)``. Each printed update starts with a carriage return
    and shows either a percentage (when the total size is known) or a rotating
    spinner, followed by the bytes received so far and the average speed since
    the meter was created. A disabled instance does nothing and never calls
    ``get_time``.
    """

    #: Characters to rotate in the spinner.
    spinner = ["|", "/", "-", "\\"]

    def __init__(
        self,
        total_bytes: Optional[int] = None,
        enabled: bool = True,
        get_time: Callable[[], float] = time.time,
    ) -> None:
        """Initialize a FetchProgress instance.

        Args:
            total_bytes: Total number of bytes to download, if known. ``None``,
                zero and negative values are all stored as 0 ("unknown").
            enabled: Whether to print progress information.
            get_time: Function to get the current time."""
        #: Number of bytes downloaded so far.
        self.current_bytes = 0
        #: Delta time between progress prints
        self.delta = 0.1
        #: Whether to print progress information.
        self.enabled = enabled
        #: Function to get the current time.
        self.get_time = get_time
        #: Time of last progress print to limit output
        self.last_printed = 0.0
        #: Time of start of download (the clock is not consulted when disabled)
        self.start_time = get_time() if enabled else 0.0
        #: Total number of bytes to download, if known.
        self.total_bytes = total_bytes if total_bytes and total_bytes > 0 else 0
        #: Index of spinner character to print (used if total bytes is unknown)
        self.index = 0

    @classmethod
    def from_headers(
        cls,
        headers: Mapping[str, str],
        enabled: bool = True,
        get_time: Callable[[], float] = time.time,
    ) -> "FetchProgress":
        """Create a FetchProgress instance from HTTP headers.

        The total size is taken from the ``Content-Length`` header; a missing,
        empty or non-integer value leaves the total unknown.
        """
        # headers.get is case-insensitive if it's from a HTTPResponse object.
        content_length = headers.get("Content-Length")
        try:
            total_bytes = int(content_length) if content_length else None
        except ValueError:
            total_bytes = None
        return cls(total_bytes=total_bytes, enabled=enabled, get_time=get_time)

    def advance(self, num_bytes: int, out=None) -> None:
        """Record ``num_bytes`` newly received bytes and maybe print an update.

        Args:
            num_bytes: Number of bytes received since the previous call.
            out: Text stream to write to; ``None`` means the current
                ``sys.stdout``.
        """
        if not self.enabled:
            return
        self.current_bytes += num_bytes
        self.print(out=out)

    def print(self, final: bool = False, out=None) -> None:
        """Write a progress line if enough time has passed since the last one.

        Args:
            final: Force an update regardless of the rate limit and terminate
                it with a newline.
            out: Text stream to write to; ``None`` means the current
                ``sys.stdout``.
        """
        if not self.enabled:
            return
        # Resolve the stream at call time: a ``sys.stdout`` default argument
        # would be bound once at definition time and miss later redirection.
        if out is None:
            out = sys.stdout
        current_time = self.get_time()
        if self.last_printed + self.delta < current_time or final:
            self.last_printed = current_time
            # print a newline if this is the final update
            maybe_newline = "\n" if final else ""
            # if we know the total bytes, show a percentage, otherwise a spinner
            if self.total_bytes > 0:
                # clamp in case more bytes arrive than were announced
                percentage = min(100 * self.current_bytes / self.total_bytes, 100.0)
                percent_or_spinner = f"[{percentage:3.0f}%] "
            else:
                # only show the spinner if we are not at 100%
                if final:
                    percent_or_spinner = "[100%] "
                else:
                    percent_or_spinner = f"[ {self.spinner[self.index]} ] "
                    self.index = (self.index + 1) % len(self.spinner)

            print(
                f"\r {percent_or_spinner}{_format_bytes(self.current_bytes)} "
                f"@ {_format_speed(self.current_bytes, current_time - self.start_time)}"
                f"{maybe_newline}",
                end="",
                flush=True,
                file=out,
            )
|
||||
|
||||
|
||||
@fetcher
|
||||
class URLFetchStrategy(FetchStrategy):
|
||||
"""URLFetchStrategy pulls source code from a URL for an archive, check the
|
||||
@ -316,7 +427,7 @@ def _check_headers(self, headers):
|
||||
tty.warn(msg)
|
||||
|
||||
@_needs_stage
|
||||
def _fetch_urllib(self, url):
|
||||
def _fetch_urllib(self, url, chunk_size=65536):
|
||||
save_file = self.stage.save_filename
|
||||
|
||||
request = urllib.request.Request(url, headers={"User-Agent": web_util.SPACK_USER_AGENT})
|
||||
@ -327,8 +438,15 @@ def _fetch_urllib(self, url):
|
||||
try:
|
||||
response = web_util.urlopen(request)
|
||||
tty.msg(f"Fetching {url}")
|
||||
progress = FetchProgress.from_headers(response.headers, enabled=sys.stdout.isatty())
|
||||
with open(save_file, "wb") as f:
|
||||
shutil.copyfileobj(response, f)
|
||||
while True:
|
||||
chunk = response.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
progress.advance(len(chunk))
|
||||
progress.print(final=True)
|
||||
except OSError as e:
|
||||
# clean up archive on failure.
|
||||
if self.archive_file:
|
||||
|
@ -2,6 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
|
||||
|
||||
from io import StringIO
|
||||
|
||||
import pytest
|
||||
|
||||
from spack import fetch_strategy
|
||||
@ -13,3 +15,136 @@ def test_fetchstrategy_bad_url_scheme():
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
fetcher = fetch_strategy.from_url_scheme("bogus-scheme://example.com/a/b/c") # noqa: F841
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "expected,total_bytes",
    [
        # The number is right-aligned in a 7-character field (format spec 7.2f).
        ("   0.00 B", 0),
        (" 999.00 B", 999),
        ("   1.00 KB", 1000),
        ("   2.05 KB", 2048),
        ("   1.00 MB", 1e6),
        ("  12.30 MB", 1.23e7),
        ("   1.23 GB", 1.23e9),
        (" 999.99 GB", 9.9999e11),
        ("5000.00 GB", 5e12),
    ],
)
def test_format_bytes(expected, total_bytes):
    """_format_bytes picks the largest decimal unit and pads the number to 7 columns."""
    assert fetch_strategy._format_bytes(total_bytes) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "expected,total_bytes,elapsed",
    [
        # The number is right-aligned in a 6-character field (format spec 6.1f).
        ("   0.0 B/s", 0, 0),  # no time passed -- defaults to 1s.
        ("   0.0 B/s", 0, 1),
        (" 999.0 B/s", 999, 1),
        ("   1.0 KB/s", 1000, 1),
        (" 500.0 B/s", 1000, 2),
        ("   2.0 KB/s", 2048, 1),
        ("   1.0 MB/s", 1e6, 1),
        (" 500.0 KB/s", 1e6, 2),
        ("  12.3 MB/s", 1.23e7, 1),
        ("   1.2 GB/s", 1.23e9, 1),
        (" 999.9 GB/s", 9.999e11, 1),
        ("5000.0 GB/s", 5e12, 1),
    ],
)
def test_format_speed(expected, total_bytes, elapsed):
    """_format_speed picks the largest decimal unit and pads the number to 6 columns."""
    assert fetch_strategy._format_speed(total_bytes, elapsed) == expected
|
||||
|
||||
|
||||
def test_fetch_progress_unknown_size():
    """Without a total size the meter shows a spinner and rate-limits its output."""
    # time stamps in seconds, with 0.1s delta except 1.5 -> 1.55.
    time_stamps = iter([1.0, 1.5, 1.55, 2.0, 3.0, 5.0, 5.5, 5.5])
    progress = fetch_strategy.FetchProgress(total_bytes=None, get_time=lambda: next(time_stamps))
    assert progress.start_time == 1.0
    out = StringIO()

    progress.advance(1000, out)
    assert progress.last_printed == 1.5
    progress.advance(50, out)
    assert progress.last_printed == 1.5  # does not print, too early after last print
    progress.advance(2000, out)
    assert progress.last_printed == 2.0
    progress.advance(3000, out)
    assert progress.last_printed == 3.0
    progress.advance(4000, out)
    assert progress.last_printed == 5.0
    progress.advance(4000, out)
    assert progress.last_printed == 5.5
    progress.print(final=True, out=out)  # finalize download

    # Byte counts are padded to 7 columns and speeds to 6 columns.
    outputs = [
        "\r [ | ]    1.00 KB @    2.0 KB/s",
        "\r [ / ]    3.05 KB @    3.0 KB/s",
        "\r [ - ]    6.05 KB @    3.0 KB/s",
        "\r [ \\ ]   10.05 KB @    2.5 KB/s",  # have to escape \ here but is aligned in output
        "\r [ | ]   14.05 KB @    3.1 KB/s",
        "\r [100%]   14.05 KB @    3.1 KB/s\n",  # final print: no spinner; newline
    ]

    assert out.getvalue() == "".join(outputs)
|
||||
|
||||
|
||||
def test_fetch_progress_known_size():
    """With a known total size the meter shows a percentage instead of a spinner."""
    time_stamps = iter([1.0, 1.5, 3.0, 4.0, 4.0])
    progress = fetch_strategy.FetchProgress(total_bytes=6000, get_time=lambda: next(time_stamps))
    out = StringIO()
    progress.advance(1000, out)  # time 1.5
    progress.advance(2000, out)  # time 3.0
    progress.advance(3000, out)  # time 4.0
    progress.print(final=True, out=out)

    # Byte counts are padded to 7 columns and speeds to 6 columns.
    outputs = [
        "\r [ 17%]    1.00 KB @    2.0 KB/s",
        "\r [ 50%]    3.00 KB @    1.5 KB/s",
        "\r [100%]    6.00 KB @    2.0 KB/s",
        "\r [100%]    6.00 KB @    2.0 KB/s\n",  # final print has newline
    ]

    assert out.getvalue() == "".join(outputs)
|
||||
|
||||
|
||||
def test_fetch_progress_disabled():
    """A disabled FetchProgress stays silent and never consults the clock."""

    def exploding_clock():
        raise RuntimeError("Should not be called")

    sink = StringIO()
    meter = fetch_strategy.FetchProgress(enabled=False, get_time=exploding_clock)
    for received in (1000, 2000):
        meter.advance(received, sink)
    meter.print(final=True, out=sink)
    assert meter.last_printed == 0
    assert not sink.getvalue()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "header,value,total_bytes",
    [
        ("Content-Length", "1234", 1234),
        ("Content-Length", "0", 0),
        ("Content-Length", "-10", 0),
        ("Content-Length", "not a number", 0),
        ("Not-Content-Length", "1234", 0),
    ],
)
def test_fetch_progress_from_headers(header, value, total_bytes):
    """from_headers reads Content-Length and falls back to 0 for unusable values."""
    clock = iter([1.0, 1.5, 3.0, 4.0, 4.0])
    response_headers = {header: value}
    meter = fetch_strategy.FetchProgress.from_headers(
        response_headers, get_time=lambda: next(clock), enabled=True
    )
    assert meter.total_bytes == total_bytes
    assert meter.enabled
    assert meter.start_time == 1.0
|
||||
|
||||
|
||||
def test_fetch_progress_from_headers_disabled():
    """from_headers passes ``enabled=False`` through to the created instance."""
    response_headers = {"Content-Length": "1234"}
    meter = fetch_strategy.FetchProgress.from_headers(
        response_headers, get_time=lambda: 1.0, enabled=False
    )
    assert not meter.enabled
|
||||
|
Loading…
Reference in New Issue
Block a user