diff --git a/lib/spack/spack/fetch_strategy.py b/lib/spack/spack/fetch_strategy.py index dd09a6d3dd1..d98b3971221 100644 --- a/lib/spack/spack/fetch_strategy.py +++ b/lib/spack/spack/fetch_strategy.py @@ -27,11 +27,14 @@ import os import re import shutil +import sys +import time import urllib.error import urllib.parse import urllib.request +import urllib.response from pathlib import PurePath -from typing import List, Optional +from typing import Callable, List, Mapping, Optional import llnl.url import llnl.util @@ -219,6 +222,114 @@ def mirror_id(self): """BundlePackages don't have a mirror id.""" +def _format_speed(total_bytes: int, elapsed: float) -> str: + """Return a human-readable average download speed string.""" + elapsed = 1 if elapsed <= 0 else elapsed # avoid divide by zero + speed = total_bytes / elapsed + if speed >= 1e9: + return f"{speed / 1e9:6.1f} GB/s" + elif speed >= 1e6: + return f"{speed / 1e6:6.1f} MB/s" + elif speed >= 1e3: + return f"{speed / 1e3:6.1f} KB/s" + return f"{speed:6.1f} B/s" + + +def _format_bytes(total_bytes: int) -> str: + """Return a human-readable total bytes string.""" + if total_bytes >= 1e9: + return f"{total_bytes / 1e9:7.2f} GB" + elif total_bytes >= 1e6: + return f"{total_bytes / 1e6:7.2f} MB" + elif total_bytes >= 1e3: + return f"{total_bytes / 1e3:7.2f} KB" + return f"{total_bytes:7.2f} B" + + +class FetchProgress: + #: Characters to rotate in the spinner. + spinner = ["|", "/", "-", "\\"] + + def __init__( + self, + total_bytes: Optional[int] = None, + enabled: bool = True, + get_time: Callable[[], float] = time.time, + ) -> None: + """Initialize a FetchProgress instance. + Args: + total_bytes: Total number of bytes to download, if known. + enabled: Whether to print progress information. + get_time: Function to get the current time.""" + #: Number of bytes downloaded so far. + self.current_bytes = 0 + #: Delta time between progress prints + self.delta = 0.1 + #: Whether to print progress information. + self.enabled = enabled + #: Function to get the current time. + self.get_time = get_time + #: Time of last progress print to limit output + self.last_printed = 0.0 + #: Time of start of download + self.start_time = get_time() if enabled else 0.0 + #: Total number of bytes to download, if known. + self.total_bytes = total_bytes if total_bytes and total_bytes > 0 else 0 + #: Index of spinner character to print (used if total bytes is unknown) + self.index = 0 + + @classmethod + def from_headers( + cls, + headers: Mapping[str, str], + enabled: bool = True, + get_time: Callable[[], float] = time.time, + ) -> "FetchProgress": + """Create a FetchProgress instance from HTTP headers.""" + # headers.get is case-insensitive if it's from a HTTPResponse object. + content_length = headers.get("Content-Length") + try: + total_bytes = int(content_length) if content_length else None + except ValueError: + total_bytes = None + return cls(total_bytes=total_bytes, enabled=enabled, get_time=get_time) + + def advance(self, num_bytes: int, out=sys.stdout) -> None: + if not self.enabled: + return + self.current_bytes += num_bytes + self.print(out=out) + + def print(self, final: bool = False, out=sys.stdout) -> None: + if not self.enabled: + return + current_time = self.get_time() + if self.last_printed + self.delta < current_time or final: + self.last_printed = current_time + # print a newline if this is the final update + maybe_newline = "\n" if final else "" + # if we know the total bytes, show a percentage, otherwise a spinner + if self.total_bytes > 0: + percentage = min(100 * self.current_bytes / self.total_bytes, 100.0) + percent_or_spinner = f"[{percentage:3.0f}%] " + else: + # only show the spinner if we are not at 100% + if final: + percent_or_spinner = "[100%] " + else: + percent_or_spinner = f"[ {self.spinner[self.index]} ] " + self.index = (self.index + 1) % len(self.spinner) + + print( + f"\r {percent_or_spinner}{_format_bytes(self.current_bytes)} " + f"@ {_format_speed(self.current_bytes, current_time - self.start_time)}" + f"{maybe_newline}", + end="", + flush=True, + file=out, + ) + + @fetcher class URLFetchStrategy(FetchStrategy): """URLFetchStrategy pulls source code from a URL for an archive, check the @@ -316,7 +427,7 @@ def _check_headers(self, headers): tty.warn(msg) @_needs_stage - def _fetch_urllib(self, url): + def _fetch_urllib(self, url, chunk_size=65536): save_file = self.stage.save_filename request = urllib.request.Request(url, headers={"User-Agent": web_util.SPACK_USER_AGENT}) @@ -327,8 +438,15 @@ def _fetch_urllib(self, url): try: response = web_util.urlopen(request) tty.msg(f"Fetching {url}") + progress = FetchProgress.from_headers(response.headers, enabled=sys.stdout.isatty()) with open(save_file, "wb") as f: - shutil.copyfileobj(response, f) + while True: + chunk = response.read(chunk_size) + if not chunk: + break + f.write(chunk) + progress.advance(len(chunk)) + progress.print(final=True) except OSError as e: # clean up archive on failure. if self.archive_file: diff --git a/lib/spack/spack/test/fetch_strategy.py b/lib/spack/spack/test/fetch_strategy.py index d0451e77c25..f349da03ef1 100644 --- a/lib/spack/spack/test/fetch_strategy.py +++ b/lib/spack/spack/test/fetch_strategy.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: (Apache-2.0 OR MIT) +from io import StringIO + import pytest from spack import fetch_strategy @@ -13,3 +15,136 @@ def test_fetchstrategy_bad_url_scheme(): with pytest.raises(ValueError): fetcher = fetch_strategy.from_url_scheme("bogus-scheme://example.com/a/b/c") # noqa: F841 + + +@pytest.mark.parametrize( + "expected,total_bytes", + [ + (" 0.00 B", 0), + (" 999.00 B", 999), + (" 1.00 KB", 1000), + (" 2.05 KB", 2048), + (" 1.00 MB", 1e6), + (" 12.30 MB", 1.23e7), + (" 1.23 GB", 1.23e9), + (" 999.99 GB", 9.9999e11), + ("5000.00 GB", 5e12), + ], +) +def test_format_bytes(expected, total_bytes): + assert fetch_strategy._format_bytes(total_bytes) == expected + + +@pytest.mark.parametrize( + "expected,total_bytes,elapsed", + [ + (" 0.0 B/s", 0, 0), # no time passed -- defaults to 1s. + (" 0.0 B/s", 0, 1), + (" 999.0 B/s", 999, 1), + (" 1.0 KB/s", 1000, 1), + (" 500.0 B/s", 1000, 2), + (" 2.0 KB/s", 2048, 1), + (" 1.0 MB/s", 1e6, 1), + (" 500.0 KB/s", 1e6, 2), + (" 12.3 MB/s", 1.23e7, 1), + (" 1.2 GB/s", 1.23e9, 1), + (" 999.9 GB/s", 9.999e11, 1), + ("5000.0 GB/s", 5e12, 1), + ], +) +def test_format_speed(expected, total_bytes, elapsed): + assert fetch_strategy._format_speed(total_bytes, elapsed) == expected + + +def test_fetch_progress_unknown_size(): + # time stamps in seconds, with 0.1s delta except 1.5 -> 1.55. + time_stamps = iter([1.0, 1.5, 1.55, 2.0, 3.0, 5.0, 5.5, 5.5]) + progress = fetch_strategy.FetchProgress(total_bytes=None, get_time=lambda: next(time_stamps)) + assert progress.start_time == 1.0 + out = StringIO() + + progress.advance(1000, out) + assert progress.last_printed == 1.5 + progress.advance(50, out) + assert progress.last_printed == 1.5 # does not print, too early after last print + progress.advance(2000, out) + assert progress.last_printed == 2.0 + progress.advance(3000, out) + assert progress.last_printed == 3.0 + progress.advance(4000, out) + assert progress.last_printed == 5.0 + progress.advance(4000, out) + assert progress.last_printed == 5.5 + progress.print(final=True, out=out) # finalize download + + outputs = [ + "\r [ | ] 1.00 KB @ 2.0 KB/s", + "\r [ / ] 3.05 KB @ 3.0 KB/s", + "\r [ - ] 6.05 KB @ 3.0 KB/s", + "\r [ \\ ] 10.05 KB @ 2.5 KB/s", # have to escape \ here but is aligned in output + "\r [ | ] 14.05 KB @ 3.1 KB/s", + "\r [100%] 14.05 KB @ 3.1 KB/s\n", # final print: no spinner; newline + ] + + assert out.getvalue() == "".join(outputs) + + +def test_fetch_progress_known_size(): + time_stamps = iter([1.0, 1.5, 3.0, 4.0, 4.0]) + progress = fetch_strategy.FetchProgress(total_bytes=6000, get_time=lambda: next(time_stamps)) + out = StringIO() + progress.advance(1000, out) # time 1.5 + progress.advance(2000, out) # time 3.0 + progress.advance(3000, out) # time 4.0 + progress.print(final=True, out=out) + + outputs = [ + "\r [ 17%] 1.00 KB @ 2.0 KB/s", + "\r [ 50%] 3.00 KB @ 1.5 KB/s", + "\r [100%] 6.00 KB @ 2.0 KB/s", + "\r [100%] 6.00 KB @ 2.0 KB/s\n", # final print has newline + ] + + assert out.getvalue() == "".join(outputs) + + +def test_fetch_progress_disabled(): + """When disabled, FetchProgress shouldn't print anything when advanced""" + + def get_time(): + raise RuntimeError("Should not be called") + + progress = fetch_strategy.FetchProgress(enabled=False, get_time=get_time) + out = StringIO() + progress.advance(1000, out) + progress.advance(2000, out) + progress.print(final=True, out=out) + assert progress.last_printed == 0 + assert not out.getvalue() + + +@pytest.mark.parametrize( + "header,value,total_bytes", + [ + ("Content-Length", "1234", 1234), + ("Content-Length", "0", 0), + ("Content-Length", "-10", 0), + ("Content-Length", "not a number", 0), + ("Not-Content-Length", "1234", 0), + ], +) +def test_fetch_progress_from_headers(header, value, total_bytes): + time_stamps = iter([1.0, 1.5, 3.0, 4.0, 4.0]) + progress = fetch_strategy.FetchProgress.from_headers( + {header: value}, get_time=lambda: next(time_stamps), enabled=True + ) + assert progress.total_bytes == total_bytes + assert progress.enabled + assert progress.start_time == 1.0 + + +def test_fetch_progress_from_headers_disabled(): + progress = fetch_strategy.FetchProgress.from_headers( + {"Content-Length": "1234"}, get_time=lambda: 1.0, enabled=False + ) + assert not progress.enabled