mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Link with cuDNN
This commit is contained in:
@@ -212,7 +212,7 @@ jobs:
|
|||||||
name: Install Python package
|
name: Install Python package
|
||||||
command: |
|
command: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
|
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev cudnn9-cuda-12
|
||||||
python3 -m venv env
|
python3 -m venv env
|
||||||
source env/bin/activate
|
source env/bin/activate
|
||||||
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
|
||||||
@@ -131,6 +132,23 @@ target_link_libraries(mlx PRIVATE CUDA::cublasLt)
|
|||||||
# Use NVRTC and driver APIs.
|
# Use NVRTC and driver APIs.
|
||||||
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)
|
||||||
|
|
||||||
|
# Download cuDNN-frontend, which is used to find cuDNN. We are not using the
|
||||||
|
# frontend APIs for now, link with "cudnn_frontend" if needed.
|
||||||
|
FetchContent_Declare(
|
||||||
|
cudnn
|
||||||
|
GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
|
||||||
|
GIT_TAG v1.12.1
|
||||||
|
GIT_SHALLOW TRUE
|
||||||
|
EXCLUDE_FROM_ALL)
|
||||||
|
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
|
||||||
|
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
|
||||||
|
FetchContent_MakeAvailable(cudnn)
|
||||||
|
# Link with the actual cuDNN libraries.
|
||||||
|
include(${cudnn_frontend_SOURCE_DIR}/cmake/cuDNN.cmake)
|
||||||
|
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)
|
||||||
|
|
||||||
# Suppress nvcc warnings on MLX headers.
|
# Suppress nvcc warnings on MLX headers.
|
||||||
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
target_compile_options(mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe
|
||||||
--diag_suppress=997>)
|
--diag_suppress=997>)
|
||||||
|
|||||||
33
mlx/backend/cuda/conv.cpp
Normal file
33
mlx/backend/cuda/conv.cpp
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
#include <fmt/format.h>
|
||||||
|
#include <nvtx3/nvtx3.hpp>
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
#define CHECK_CUDNNS_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
|
||||||
|
|
||||||
|
void check_cudnn_error(const char* name, cudnnStatus_t err) {
|
||||||
|
if (err != CUDNN_STATUS_SUCCESS) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
nvtx3::scoped_range r("Convolution::eval_gpu");
|
||||||
|
auto& s = stream();
|
||||||
|
auto& encoder = cu::get_command_encoder(s);
|
||||||
|
|
||||||
|
assert(inputs.size() == 2);
|
||||||
|
const auto& in = inputs[0];
|
||||||
|
const auto& wt = inputs[1];
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
|
||||||
|
throw std::runtime_error("NYI");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
@@ -41,9 +41,12 @@ Device::Device(int device) : device_(device) {
|
|||||||
// The cublasLt handle is used by matmul.
|
// The cublasLt handle is used by matmul.
|
||||||
make_current();
|
make_current();
|
||||||
cublasLtCreate(<_);
|
cublasLtCreate(<_);
|
||||||
|
// The cudnn handle is used by Convolution.
|
||||||
|
cudnnCreate(&cudnn_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
|
cudnnDestroy(cudnn_);
|
||||||
cublasLtDestroy(lt_);
|
cublasLtDestroy(lt_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include <cublasLt.h>
|
#include <cublasLt.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
|
#include <cudnn.h>
|
||||||
#include <thrust/execution_policy.h>
|
#include <thrust/execution_policy.h>
|
||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@@ -137,12 +138,16 @@ class Device {
|
|||||||
cublasLtHandle_t lt_handle() const {
|
cublasLtHandle_t lt_handle() const {
|
||||||
return lt_;
|
return lt_;
|
||||||
}
|
}
|
||||||
|
cudnnHandle_t cudnn_handle() const {
|
||||||
|
return cudnn_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int device_;
|
int device_;
|
||||||
int compute_capability_major_;
|
int compute_capability_major_;
|
||||||
int compute_capability_minor_;
|
int compute_capability_minor_;
|
||||||
cublasLtHandle_t lt_;
|
cublasLtHandle_t lt_;
|
||||||
|
cudnnHandle_t cudnn_;
|
||||||
std::unordered_map<int, CommandEncoder> encoders_;
|
std::unordered_map<int, CommandEncoder> encoders_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,6 @@ bool fast::ScaledDotProductAttention::use_fallback(
|
|||||||
}
|
}
|
||||||
|
|
||||||
NO_GPU(BlockMaskedMM)
|
NO_GPU(BlockMaskedMM)
|
||||||
NO_GPU(Convolution)
|
|
||||||
NO_GPU(DynamicSlice)
|
NO_GPU(DynamicSlice)
|
||||||
NO_GPU(DynamicSliceUpdate)
|
NO_GPU(DynamicSliceUpdate)
|
||||||
NO_GPU(FFT)
|
NO_GPU(FFT)
|
||||||
|
|||||||
Reference in New Issue
Block a user