mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
[CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN * Initial implementation * Remove backend apis * Fix recording cudnn conv * More unused backend apis * Fix C++ conv tests * include cudnn as python dep * Install libcudnn9-dev-cuda-12 in CI * cudnn only accepts contiguous inputs * Switch to backend apis * Plan needs to be kept alive * Turn off tf32 * Add cache * Test the native cuda graph api * Set cudnn stream before execution * Make LRUCache more like a normal container * Do error check for cublas handle * Zero-initilizing array * Use tf32 for conv * Skip TestConv.test_torch_conv_2D test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -216,6 +216,7 @@ jobs:
|
||||
name: Install Python package
|
||||
command: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
python3 -m venv env
|
||||
source env/bin/activate
|
||||
@@ -385,7 +386,7 @@ jobs:
|
||||
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
sudo apt-get update
|
||||
sudo apt install cuda-toolkit-12-9
|
||||
sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
|
||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
||||
sudo apt-get install zip
|
||||
pip install auditwheel
|
||||
|
Reference in New Issue
Block a user