Link with cuDNN

Author: Cheng
Date: 2025-07-17 01:34:12 +00:00
Commit: 04bd515370 (parent: d1f4d291e8)
6 changed files with 60 additions and 2 deletions


@@ -212,7 +212,7 @@ jobs:
       name: Install Python package
       command: |
         sudo apt-get update
-        sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+        sudo apt-get install libblas-dev liblapack-dev liblapacke-dev cudnn9-cuda-12
         python3 -m venv env
         source env/bin/activate
         CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \


@@ -15,6 +15,7 @@ target_sources(
     ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
     ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+    ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
@@ -131,6 +132,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
 # Use NVRTC and driver APIs.
 target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
+
+# Download cuDNN-frontend, which is used to find cuDNN. We are not using the
+# frontend APIs for now; link with "cudnn_frontend" if needed.
+FetchContent_Declare(
+  cudnn
+  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
+  GIT_TAG v1.12.1
+  GIT_SHALLOW TRUE
+  EXCLUDE_FROM_ALL)
+set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
+set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
+set(CUDNN_FRONTEND_BUILD_TESTS OFF)
+set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
+FetchContent_MakeAvailable(cudnn)
+
+# Link with the actual cuDNN libraries.
+include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
+target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)

 # Suppress nvcc warnings on MLX headers.
 target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                                    --diag_suppress=997>)
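
The CUDNN::cudnn_all target is defined by the cuDNN.cmake module that ships inside cudnn-frontend; it locates the system's cuDNN headers and shared libraries, which is why the frontend is fetched even though its C++ API is unused here. A quick way to verify the resulting link is a standalone check (a sketch, not part of this commit) that compares the compile-time and runtime cuDNN versions and creates a handle:

// Standalone link check (illustrative, not part of this commit).
#include <cudnn.h>

#include <cstdio>

int main() {
  // A mismatch between these two usually means the wrong libcudnn was
  // picked up at link or load time.
  std::printf(
      "compiled against cuDNN %d, running %zu\n",
      CUDNN_VERSION,
      cudnnGetVersion());

  cudnnHandle_t handle;
  if (cudnnCreate(&handle) != CUDNN_STATUS_SUCCESS) {
    std::printf("cudnnCreate failed\n");
    return 1;
  }
  cudnnDestroy(handle);
  return 0;
}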

mlx/backend/cuda/conv.cpp (new file)

@@ -0,0 +1,33 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/primitives.h"
+
+#include <cassert>
+
+#include <fmt/format.h>
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
+void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Convolution::eval_gpu");
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  assert(inputs.size() == 2);
+  const auto& in = inputs[0];
+  const auto& wt = inputs[1];
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  // Placeholder: the cuDNN descriptor setup and dispatch are not wired up yet.
+  throw std::runtime_error("NYI");
+}
+
+} // namespace mlx::core

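eval_gpu currently just allocates the output and throws "NYI", so this commit only wires up linking and the per-device handle. For orientation, here is a standalone sketch of what a forward convolution through the classic cuDNN descriptor API looks like; the layout (NHWC), dtype (float32), and fixed algorithm are illustrative assumptions, and the eventual implementation may use the frontend graph API instead:

// Illustrative cuDNN forward convolution (NHWC float32), not the
// commit's implementation.
#include <cudnn.h>

#include <stdexcept>

void conv2d_forward(
    cudnnHandle_t handle,
    const float* x, int n, int h, int w, int c, // input, NHWC
    const float* filt, int k, int r, int s, // weights, KRSC
    float* y, // output, NHWC
    void* workspace, size_t workspace_bytes,
    int pad = 0, int stride = 1, int dilation = 1) {
  cudnnTensorDescriptor_t x_desc, y_desc;
  cudnnFilterDescriptor_t w_desc;
  cudnnConvolutionDescriptor_t conv_desc;
  cudnnCreateTensorDescriptor(&x_desc);
  cudnnCreateTensorDescriptor(&y_desc);
  cudnnCreateFilterDescriptor(&w_desc);
  cudnnCreateConvolutionDescriptor(&conv_desc);

  cudnnSetTensor4dDescriptor(
      x_desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_FLOAT, n, c, h, w);
  cudnnSetFilter4dDescriptor(
      w_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NHWC, k, c, r, s);
  cudnnSetConvolution2dDescriptor(
      conv_desc, pad, pad, stride, stride, dilation, dilation,
      CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);

  // Let cuDNN compute the output shape, then describe y with it.
  int on, oc, oh, ow;
  cudnnGetConvolution2dForwardOutputDim(
      conv_desc, x_desc, w_desc, &on, &oc, &oh, &ow);
  cudnnSetTensor4dDescriptor(
      y_desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_FLOAT, on, oc, oh, ow);

  // A fixed algorithm keeps the sketch short; real code would query
  // cudnnGetConvolutionForwardAlgorithm_v7 and size the workspace to match.
  auto algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  float alpha = 1.f, beta = 0.f;
  auto err = cudnnConvolutionForward(
      handle, &alpha, x_desc, x, w_desc, filt, conv_desc, algo,
      workspace, workspace_bytes, &beta, y_desc, y);
  if (err != CUDNN_STATUS_SUCCESS) {
    throw std::runtime_error(cudnnGetErrorString(err));
  }

  cudnnDestroyConvolutionDescriptor(conv_desc);
  cudnnDestroyFilterDescriptor(w_desc);
  cudnnDestroyTensorDescriptor(y_desc);
  cudnnDestroyTensorDescriptor(x_desc);
}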

@@ -41,9 +41,12 @@ Device::Device(int device) : device_(device) {
   // The cublasLt handle is used by matmul.
   make_current();
   cublasLtCreate(&lt_);
+  // The cudnn handle is used by Convolution.
+  cudnnCreate(&cudnn_);
 }

 Device::~Device() {
+  cudnnDestroy(cudnn_);
   cublasLtDestroy(lt_);
 }

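cudnnCreate binds the new handle to whichever device is current, which is why the constructor runs make_current() first; work submitted through the handle then follows whatever CUDA stream is attached with cudnnSetStream (the legacy default stream if none is set). A minimal standalone sketch of that lifetime pattern (class and method names here are invented for illustration, not MLX's):

// Per-device cuDNN handle lifetime, mirroring the Device ctor/dtor above
// (illustrative; names are assumptions).
#include <cuda_runtime.h>
#include <cudnn.h>

class CudnnContext {
 public:
  explicit CudnnContext(int device) {
    // The handle binds to the device that is current at creation time,
    // so switch devices first.
    cudaSetDevice(device);
    cudnnCreate(&handle_);
  }
  ~CudnnContext() {
    cudnnDestroy(handle_);
  }

  // Route subsequent cuDNN launches onto a specific stream.
  void bind_stream(cudaStream_t stream) {
    cudnnSetStream(handle_, stream);
  }

  cudnnHandle_t handle() const {
    return handle_;
  }

 private:
  cudnnHandle_t handle_;
};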

@@ -8,6 +8,7 @@
 #include <cublasLt.h>
 #include <cuda.h>
+#include <cudnn.h>
 #include <thrust/execution_policy.h>

 #include <unordered_map>
@@ -137,12 +138,16 @@ class Device {
   cublasLtHandle_t lt_handle() const {
     return lt_;
   }
+  cudnnHandle_t cudnn_handle() const {
+    return cudnn_;
+  }

  private:
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
   cublasLtHandle_t lt_;
+  cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
 };


@@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
 }

 NO_GPU(BlockMaskedMM)
-NO_GPU(Convolution)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)