mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
include cudnn as python dep
This commit is contained in:
@@ -16,7 +16,7 @@ rm "${repaired_wheel}"
|
|||||||
mlx_so="mlx/lib/libmlx.so"
|
mlx_so="mlx/lib/libmlx.so"
|
||||||
rpath=$(patchelf --print-rpath "${mlx_so}")
|
rpath=$(patchelf --print-rpath "${mlx_so}")
|
||||||
base="\$ORIGIN/../../nvidia"
|
base="\$ORIGIN/../../nvidia"
|
||||||
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib
|
rpath=$rpath:${base}/cublas/lib:${base}/cuda_nvrtc/lib:${base}/cudnn/lib
|
||||||
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
patchelf --force-rpath --set-rpath "$rpath" "$mlx_so"
|
||||||
python ../python/scripts/repair_record.py ${mlx_so}
|
python ../python/scripts/repair_record.py ${mlx_so}
|
||||||
|
|
||||||
|
|||||||
1
setup.py
1
setup.py
@@ -289,6 +289,7 @@ if __name__ == "__main__":
|
|||||||
install_requires += [
|
install_requires += [
|
||||||
"nvidia-cublas-cu12==12.9.*",
|
"nvidia-cublas-cu12==12.9.*",
|
||||||
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
"nvidia-cuda-nvrtc-cu12==12.9.*",
|
||||||
|
"nvidia-cudnn-cu12==12.9.*",
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
name = "mlx-cpu"
|
name = "mlx-cpu"
|
||||||
|
|||||||
Reference in New Issue
Block a user