Install linux with mlx[cuda] and mlx[cpu] (#2356)

* install linux with mlx[cuda] and mlx[cpu]

* temp for testing

* cleanup circle, fix cuda repair

* update circle

* update circle

* decouple python bindings from core libraries
This commit is contained in:
Awni Hannun
2025-07-14 17:17:33 -07:00
committed by GitHub
parent 49114f28ab
commit f0a0b077a0
8 changed files with 264 additions and 212 deletions

View File

@@ -1,17 +1,23 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--plat manylinux_2_39_x86_64 \
--exclude libcublas* \
--exclude libnvrtc*
--exclude libnvrtc* \
-w wheel_tmp
cd wheelhouse
mkdir wheelhouse
cd wheel_tmp
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit)
rpath=$(patchelf --print-rpath "${core_so}")
rpath=$rpath:\$ORIGIN/../nvidia/cublas/lib:\$ORIGIN/../nvidia/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
rm "${repaired_wheel}"
mlx_so="mlx/lib/libmlx.so"
rpath=$(patchelf --print-rpath "${mlx_so}")
base="\$ORIGIN/../../nvidia"
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
python ../python/scripts/repair_record.py ${mlx_so}
# Re-zip the repaired wheel
zip -r -q "${repaired_wheel}" .
zip -r -q "../wheelhouse/${repaired_wheel}" .

View File

@@ -0,0 +1,19 @@
#!/bin/bash
auditwheel repair dist/* \
--plat manylinux_2_35_x86_64 \
--exclude libmlx* \
-w wheel_tmp
mkdir wheelhouse
cd wheel_tmp
repaired_wheel=$(find . -name "*.whl" -print -quit)
unzip -q "${repaired_wheel}"
rm "${repaired_wheel}"
core_so=$(find mlx -name "core*.so" -print -quit)
rpath="\$ORIGIN/lib"
patchelf --force-rpath --set-rpath "$rpath" "$core_so"
python ../python/scripts/repair_record.py ${core_so}
# Re-zip the repaired wheel
zip -r -q "../wheelhouse/${repaired_wheel}" .

View File

@@ -0,0 +1,33 @@
import base64
import glob
import hashlib
import sys
filename = sys.argv[1]
# Compute the new hash and size
def urlsafe_b64encode(data: bytes) -> bytes:
return base64.urlsafe_b64encode(data).rstrip(b"=")
hasher = hashlib.sha256()
with open(filename, "rb") as f:
data = f.read()
hasher.update(data)
hash_str = urlsafe_b64encode(hasher.digest()).decode("ascii")
size = len(data)
# Update the record file
record_file = glob.glob("*/RECORD")[0]
with open(record_file, "r") as f:
lines = [l.split(",") for l in f.readlines()]
for l in lines:
if filename == l[0]:
l[1] = hash_str
l[2] = f"{size}\n"
with open(record_file, "w") as f:
for l in lines:
f.write(",".join(l))