Merge remote-tracking branch 'upstream/main' into conda-channels

This commit is contained in:
YuviPanda
2023-09-29 14:27:25 -07:00
180 changed files with 7796 additions and 6028 deletions

View File

@@ -1,14 +1,17 @@
"""
Wrap conda commandline program
"""
import contextlib
import hashlib
import json
import logging
import os
import subprocess
import json
import hashlib
import contextlib
import tempfile
import time
import requests
from distutils.version import LooseVersion as V
from tljh import utils
@@ -25,23 +28,21 @@ def sha256_file(fname):
return hash_sha256.hexdigest()
def check_miniconda_version(prefix, version):
"""
Return true if a miniconda install with version exists at prefix
"""
def get_conda_package_versions(prefix):
"""Get conda package versions, via `conda list --json`"""
versions = {}
try:
installed_version = (
subprocess.check_output(
[os.path.join(prefix, "bin", "conda"), "-V"], stderr=subprocess.STDOUT
)
.decode()
.strip()
.split()[1]
out = subprocess.check_output(
[os.path.join(prefix, "bin", "conda"), "list", "--json"],
text=True,
)
return V(installed_version) >= V(version)
except (subprocess.CalledProcessError, FileNotFoundError):
# Conda doesn't exist
return False
return versions
packages = json.loads(out)
for package in packages:
versions[package["name"]] = package["version"]
return versions
@contextlib.contextmanager
@@ -53,14 +54,21 @@ def download_miniconda_installer(installer_url, sha256sum):
of given version, verifies the sha256sum & provides path to it to the `with`
block to run.
"""
with tempfile.NamedTemporaryFile("wb") as f:
f.write(requests.get(installer_url).content)
logger = logging.getLogger("tljh")
logger.info(f"Downloading conda installer {installer_url}")
with tempfile.NamedTemporaryFile("wb", suffix=".sh") as f:
tic = time.perf_counter()
r = requests.get(installer_url)
r.raise_for_status()
f.write(r.content)
# Remain in the NamedTemporaryFile context, but flush changes, see:
# https://docs.python.org/3/library/os.html#os.fsync
f.flush()
os.fsync(f.fileno())
t = time.perf_counter() - tic
logger.info(f"Downloaded conda installer {installer_url} in {t:.1f}s")
if sha256_file(f.name) != sha256sum:
if sha256sum and sha256_file(f.name) != sha256sum:
raise Exception("sha256sum hash mismatch! Downloaded file corrupted")
yield f.name
@@ -90,48 +98,38 @@ def install_miniconda(installer_path, prefix):
fix_permissions(prefix)
def ensure_conda_packages(prefix, packages, channels=('conda-forge',)):
def ensure_conda_packages(prefix, packages, channels=('conda-forge',), force_reinstall=False):
"""
Ensure packages (from channels) are installed in the conda prefix.
Note that conda seem to update dependencies by default, so there is probably
no need to have a update parameter exposed for this function.
"""
conda_executable = [os.path.join(prefix, "bin", "mamba")]
conda_executable = os.path.join(prefix, "bin", "mamba")
if not os.path.isfile(conda_executable):
# fallback on conda if mamba is not present (e.g. for mamba to install itself)
conda_executable = os.path.join(prefix, "bin", "conda")
cmd = [conda_executable, "install", "--yes"]
if force_reinstall:
# use force-reinstall, e.g. for conda/mamba to ensure everything is okay
# avoids problems with RemoveError upgrading conda from old versions
cmd += ["--force-reinstall"]
cmd += ["-c", channel for channel in channels]
abspath = os.path.abspath(prefix)
# Let subprocess errors propagate
# Explicitly do *not* capture stderr, since that's not always JSON!
# Scripting conda is a PITA!
# FIXME: raise different exception when using
channel_cmd = '-c ' + ' -c '.join(channels)
raw_output = subprocess.check_output(
conda_executable
utils.run_subprocess(
cmd
+ [
"install",
"--json",
"--prefix",
abspath,
]
+ channel_cmd.split()
+ packages
).decode()
# `conda install` outputs JSON lines for fetch updates,
# and a undelimited output at the end. There is no reasonable way to
# parse this outside of this kludge.
filtered_output = "\n".join(
[
l
for l in raw_output.split("\n")
# Sometimes the JSON messages start with a \x00. The lstrip removes these.
# conda messages seem to randomly throw \x00 in places for no reason
if not l.lstrip("\x00").startswith('{"fetch"')
]
+ packages,
input="",
)
output = json.loads(filtered_output.lstrip("\x00"))
if "success" in output and output["success"] == True:
return
fix_permissions(prefix)