Link with cuDNN
@@ -212,7 +212,7 @@ jobs:
       name: Install Python package
       command: |
         sudo apt-get update
-        sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+        sudo apt-get install libblas-dev liblapack-dev liblapacke-dev cudnn9-cuda-12
         python3 -m venv env
         source env/bin/activate
         CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
@@ -15,6 +15,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
@@ -131,6 +132,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
 # Use NVRTC and driver APIs.
 target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
 
+# Download cuDNN-frontend, which is used to find cuDNN. We are not using the
+# frontend APIs for now, link with "cudnn_frontend" if needed.
+FetchContent_Declare(
+  cudnn
+  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
+  GIT_TAG v1.12.1
+  GIT_SHALLOW TRUE
+  EXCLUDE_FROM_ALL)
+set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
+set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
+set(CUDNN_FRONTEND_BUILD_TESTS OFF)
+set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
+FetchContent_MakeAvailable(cudnn)
+# Link with the actual cuDNN libraries.
+include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
+target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
+
 # Suppress nvcc warnings on MLX headers.
 target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
                                    --diag_suppress=997>)
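As a quick way to confirm that the libraries pulled in by CUDNN::cudnn_all are actually usable, the following standalone program (not part of this commit, shown only for context) creates a handle and prints the runtime version. It relies only on the public cuDNN C API; build it against the CUDA/cuDNN headers and link with -lcudnn.

// Minimal sanity check for the cuDNN link step above (illustrative only,
// not part of this commit).
#include <cudnn.h>

#include <cstdio>
#include <cstdlib>

int main() {
  cudnnHandle_t handle;
  cudnnStatus_t err = cudnnCreate(&handle);
  if (err != CUDNN_STATUS_SUCCESS) {
    std::fprintf(stderr, "cudnnCreate failed: %s\n", cudnnGetErrorString(err));
    return EXIT_FAILURE;
  }
  // cudnnGetVersion() reports the runtime library version, e.g. 9xxxx for cuDNN 9.
  std::printf("cuDNN version: %zu\n", cudnnGetVersion());
  cudnnDestroy(handle);
  return EXIT_SUCCESS;
}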
mlx/backend/cuda/conv.cpp (new file, +33)
@@ -0,0 +1,33 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/primitives.h"
+
+#include <fmt/format.h>
+#include <nvtx3/nvtx3.hpp>
+
+namespace mlx::core {
+
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
+void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Convolution::eval_gpu");
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  assert(inputs.size() == 2);
+  const auto& in = inputs[0];
+  const auto& wt = inputs[1];
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  throw std::runtime_error("NYI");
+}
+
+} // namespace mlx::core
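Since eval_gpu above only allocates the output and then throws "NYI", here is a hedged sketch of what a descriptor-based forward convolution looks like with the legacy cuDNN API, using the CHECK_CUDNN_ERROR macro defined above and a handle like the one created in the Device constructor below. The NHWC layout, float32 type, unit dilation, and the fixed algorithm are illustrative assumptions; none of this is the eventual MLX implementation, which is not part of this commit (the CMake comment above points at the cuDNN-frontend graph API as the alternative path).

// Hedged sketch only: a legacy-API cuDNN forward convolution for float32
// NHWC data. Assumes the CHECK_CUDNN_ERROR macro from conv.cpp above is in
// scope; all pointers are device memory; `workspace` must be at least the
// size reported by cudnnGetConvolutionForwardWorkspaceSize for `algo`.
#include <cudnn.h>

void conv2d_forward_nhwc(
    cudnnHandle_t handle,
    const float* x, int n, int h, int w, int c,   // input  N x H x W x C
    const float* filt, int k, int r, int s,       // filter K x R x S x C
    float* y, int out_h, int out_w,               // output N x OH x OW x K
    int pad_h, int pad_w, int stride_h, int stride_w,
    void* workspace, size_t workspace_size) {
  cudnnTensorDescriptor_t x_desc, y_desc;
  cudnnFilterDescriptor_t w_desc;
  cudnnConvolutionDescriptor_t conv_desc;
  CHECK_CUDNN_ERROR(cudnnCreateTensorDescriptor(&x_desc));
  CHECK_CUDNN_ERROR(cudnnCreateTensorDescriptor(&y_desc));
  CHECK_CUDNN_ERROR(cudnnCreateFilterDescriptor(&w_desc));
  CHECK_CUDNN_ERROR(cudnnCreateConvolutionDescriptor(&conv_desc));

  CHECK_CUDNN_ERROR(cudnnSetTensor4dDescriptor(
      x_desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_FLOAT, n, c, h, w));
  CHECK_CUDNN_ERROR(cudnnSetFilter4dDescriptor(
      w_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NHWC, k, c, r, s));
  CHECK_CUDNN_ERROR(cudnnSetConvolution2dDescriptor(
      conv_desc, pad_h, pad_w, stride_h, stride_w, /*dilation_h=*/1,
      /*dilation_w=*/1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT));
  CHECK_CUDNN_ERROR(cudnnSetTensor4dDescriptor(
      y_desc, CUDNN_TENSOR_NHWC, CUDNN_DATA_FLOAT, n, k, out_h, out_w));

  // A fixed algorithm keeps the sketch short; real code would query the
  // heuristics (or use the cuDNN-frontend graph API mentioned in CMake).
  auto algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  float alpha = 1.0f, beta = 0.0f;
  CHECK_CUDNN_ERROR(cudnnConvolutionForward(
      handle, &alpha, x_desc, x, w_desc, filt, conv_desc, algo,
      workspace, workspace_size, &beta, y_desc, y));

  CHECK_CUDNN_ERROR(cudnnDestroyConvolutionDescriptor(conv_desc));
  CHECK_CUDNN_ERROR(cudnnDestroyFilterDescriptor(w_desc));
  CHECK_CUDNN_ERROR(cudnnDestroyTensorDescriptor(y_desc));
  CHECK_CUDNN_ERROR(cudnnDestroyTensorDescriptor(x_desc));
}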
@@ -41,9 +41,12 @@ Device::Device(int device) : device_(device) {
   // The cublasLt handle is used by matmul.
   make_current();
   cublasLtCreate(&lt_);
+  // The cudnn handle is used by Convolution.
+  cudnnCreate(&cudnn_);
 }
 
 Device::~Device() {
+  cudnnDestroy(cudnn_);
   cublasLtDestroy(lt_);
 }
 
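One detail this hunk does not show: a cuDNN handle submits its work to whichever CUDA stream it was last bound to (the default stream if none). If Convolution is meant to run on the encoder's stream, a call along these lines would be needed before launch; the helper below is purely illustrative and not part of this commit.

// Illustrative only (not in this commit): bind the per-device cuDNN handle to
// the stream that subsequent kernels run on. Assumes CHECK_CUDNN_ERROR from
// conv.cpp above is in scope.
void bind_cudnn_to_stream(cudnnHandle_t handle, cudaStream_t stream) {
  CHECK_CUDNN_ERROR(cudnnSetStream(handle, stream));
}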
@@ -8,6 +8,7 @@
 
 #include <cublasLt.h>
 #include <cuda.h>
+#include <cudnn.h>
 #include <thrust/execution_policy.h>
 
 #include <unordered_map>
@@ -137,12 +138,16 @@ class Device {
   cublasLtHandle_t lt_handle() const {
     return lt_;
   }
+  cudnnHandle_t cudnn_handle() const {
+    return cudnn_;
+  }
 
  private:
   int device_;
   int compute_capability_major_;
   int compute_capability_minor_;
   cublasLtHandle_t lt_;
+  cudnnHandle_t cudnn_;
   std::unordered_map<int, CommandEncoder> encoders_;
 };
 
@@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
 }
 
 NO_GPU(BlockMaskedMM)
-NO_GPU(Convolution)
 NO_GPU(DynamicSlice)
 NO_GPU(DynamicSliceUpdate)
 NO_GPU(FFT)