mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	[CUDA] Initial implementation of Convolution with cuDNN (#2385)
* Link with cuDNN * Initial implementation * Remove backend apis * Fix recording cudnn conv * More unused backend apis * Fix C++ conv tests * include cudnn as python dep * Install libcudnn9-dev-cuda-12 in CI * cudnn only accepts contiguous inputs * Switch to backend apis * Plan needs to be kept alive * Turn off tf32 * Add cache * Test the native cuda graph api * Set cudnn stream before execution * Make LRUCache more like a normal container * Do error check for cublas handle * Zero-initilizing array * Use tf32 for conv * Skip TestConv.test_torch_conv_2D test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		| @@ -216,6 +216,7 @@ jobs: | ||||
|           name: Install Python package | ||||
|           command: | | ||||
|             sudo apt-get update | ||||
|             sudo apt-get install libcudnn9-dev-cuda-12 | ||||
|             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev | ||||
|             python3 -m venv env | ||||
|             source env/bin/activate | ||||
| @@ -385,7 +386,7 @@ jobs: | ||||
|             wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb | ||||
|             sudo dpkg -i cuda-keyring_1.1-1_all.deb | ||||
|             sudo apt-get update | ||||
|             sudo apt install cuda-toolkit-12-9 | ||||
|             sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12 | ||||
|             sudo apt-get install libblas-dev liblapack-dev liblapacke-dev | ||||
|             sudo apt-get install zip | ||||
|             pip install auditwheel | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Cheng
					Cheng