From 6571df6ad76eb30d92757364058f24938d03036d Mon Sep 17 00:00:00 2001
From: Cheng
Date: Thu, 17 Jul 2025 23:46:09 +0000
Subject: [PATCH] Remove backend apis

---
 .circleci/config.yml        |   2 +-
 mlx/backend/cuda/conv.cpp   | 154 +++++-------------------------------
 mlx/backend/cuda/device.cpp |  24 ++++--
 tests/ops_tests.cpp         |  22 +++++-
 4 files changed, 57 insertions(+), 145 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index c551e1486..38c993401 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -212,7 +212,7 @@ jobs:
           name: Install Python package
           command: |
             sudo apt-get update
-            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev cudnn9-cuda-12
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
             python3 -m venv env
             source env/bin/activate
             CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp
index fcdd57555..3c5d4e532 100644
--- a/mlx/backend/cuda/conv.cpp
+++ b/mlx/backend/cuda/conv.cpp
@@ -14,30 +14,17 @@
 #include <cudnn_frontend.h>
 #include <fmt/format.h>
 
-// cudnn_frontend.h redefines this macro.
-#undef CHECK_CUDNN_ERROR
-#undef CHECK_CUDNN_FRONTEND_ERROR
-
 namespace mlx::core {
 
 namespace cu {
 
 using namespace cudnn_frontend;
 
-#define CHECK_CUDNN_FRONTEND_ERROR(cmd)                          \
+#define CHECK_CUDNN_FE_ERROR(cmd)                                \
   if (cmd.is_bad()) {                                            \
     throw std::runtime_error(fmt::format("{} failed.", #cmd));   \
   }
 
-#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
-
-void check_cudnn_error(const char* name, cudnnStatus_t err) {
-  if (err != CUDNN_STATUS_SUCCESS) {
-    throw std::runtime_error(
-        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
-  }
-}
-
 auto swapaxes(const array& in, int axis1, int axis2) {
   std::vector<int> axes(in.ndim());
   std::iota(axes.begin(), axes.end(), 0);
@@ -63,7 +50,8 @@ class Convolution {
       const std::vector<int64_t>& output_shape,
       const std::vector<int64_t>& output_strides,
       const std::vector<int64_t>& stride,
-      const std::vector<int64_t>& padding,
+      const std::vector<int64_t>& padding_lo,
+      const std::vector<int64_t>& padding_hi,
      const std::vector<int64_t>& dilation,
       int groups)
       : handle_(device.cudnn_handle()) {
@@ -73,119 +61,31 @@ class Convolution {
     graph_.set_io_data_type(cudnn_type)
         .set_compute_data_type(is_half ? DataType_t::FLOAT : cudnn_type);
 
     input_attr_ = graph_.tensor(graph::Tensor_attributes()
+                                    .set_name("input")
                                     .set_dim(input_shape)
                                     .set_stride(input_strides));
     filter_attr_ = graph_.tensor(graph::Tensor_attributes()
+                                     .set_name("filter")
                                      .set_dim(filter_shape)
                                      .set_stride(filter_strides));
     auto conv_options = graph::Conv_fprop_attributes()
-                            .set_padding(padding)
+                            .set_pre_padding(padding_lo)
+                            .set_post_padding(padding_hi)
                             .set_stride(stride)
                             .set_dilation(dilation);
     output_attr_ = graph_.conv_fprop(input_attr_, filter_attr_, conv_options);
     output_attr_->set_output(true);
+    output_attr_->set_data_type(cudnn_type);
+    output_attr_->set_dim(output_shape);
+    output_attr_->set_stride(output_strides);
 
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.validate());
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(handle_));
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.create_execution_plans({HeurMode_t::A}));
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.check_support(handle_));
-    CHECK_CUDNN_FRONTEND_ERROR(graph_.build_plans(handle_));
-
-#if 0
-    int ndim = input_shape.size();
-    CHECK_CUDNN_ERROR(cudnnCreateTensorDescriptor(&input_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
-        input_desc_,
-        cudnn_type,
-        ndim,
-        input_shape.data(),
-        input_strides.data()));
-
-    CHECK_CUDNN_ERROR(cudnnCreateFilterDescriptor(&filter_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
-        filter_desc_,
-        cudnn_type,
-        CUDNN_TENSOR_NCHW,
-        ndim,
-        filter_shape.data()));
-
-    CHECK_CUDNN_ERROR(cudnnCreateTensorDescriptor(&output_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
-        output_desc_,
-        cudnn_type,
-        ndim,
-        output_shape.data(),
-        output_strides.data()));
-
-    CHECK_CUDNN_ERROR(cudnnCreateConvolutionDescriptor(&conv_desc_));
-    CHECK_CUDNN_ERROR(cudnnSetConvolutionGroupCount(conv_desc_, groups));
-    CHECK_CUDNN_ERROR(cudnnSetConvolutionNdDescriptor(
-        conv_desc_,
-        ndim - 2,
-        padding.data(),
-        stride.data(),
-        dilation.data(),
-        CUDNN_CROSS_CORRELATION,
-        is_half ? CUDNN_DATA_FLOAT : cudnn_type));
-    if (is_half) {
-      CHECK_CUDNN_ERROR(
-          cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
-    } else if (dtype == float32) {
-      CHECK_CUDNN_ERROR(
-          cudnnSetConvolutionMathType(conv_desc_, CUDNN_FMA_MATH));
-    } else {
-      CHECK_CUDNN_ERROR(
-          cudnnSetConvolutionMathType(conv_desc_, CUDNN_DEFAULT_MATH));
-    }
-
-    std::vector<int> expected_output_shape(ndim);
-    CHECK_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
-        conv_desc_,
-        input_desc_,
-        filter_desc_,
-        ndim,
-        expected_output_shape.data()));
-    std::cout << "expected_output_shape: " << expected_output_shape
-              << std::endl;
-
-    cudnnConvolutionFwdAlgoPerf_t results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
-    int count;
-    CHECK_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
-        handle_,
-        input_desc_,
-        filter_desc_,
-        conv_desc_,
-        output_desc_,
-        std::size(results),
-        &count,
-        results));
-    for (int i = 0; i < count; ++i) {
-      if (results[i].status == CUDNN_STATUS_SUCCESS) {
-        algo_ = results[i].algo;
-        std::cout << "Found algorithm" << std::endl;
-        break;
-      }
-    }
-
-    CHECK_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
-        handle_,
-        input_desc_,
-        filter_desc_,
-        conv_desc_,
-        output_desc_,
-        algo_,
-        &workspace_size_));
-#endif
-  }
-
-  ~Convolution() {
-#if 0
-    cudnnDestroyTensorDescriptor(input_desc_);
-    cudnnDestroyFilterDescriptor(filter_desc_);
-    cudnnDestroyTensorDescriptor(output_desc_);
-    cudnnDestroyConvolutionDescriptor(conv_desc_);
-#endif
+    CHECK_CUDNN_FE_ERROR(graph_.validate());
+    CHECK_CUDNN_FE_ERROR(graph_.build_operation_graph(handle_));
+    CHECK_CUDNN_FE_ERROR(graph_.create_execution_plans({HeurMode_t::A}));
+    CHECK_CUDNN_FE_ERROR(graph_.check_support(handle_));
+    CHECK_CUDNN_FE_ERROR(graph_.build_plans(handle_));
+    CHECK_CUDNN_FE_ERROR(graph_.get_workspace_size(workspace_size_));
   }
 
   void run(
@@ -208,25 +108,8 @@ class Convolution {
         {output_attr_->get_uid(), output}};
 
     auto capture = encoder.capture_context();
-    CHECK_CUDNN_ERROR(cudnnSetStream(handle_, encoder.stream()));
-    CHECK_CUDNN_FRONTEND_ERROR(
+    CHECK_CUDNN_FE_ERROR(
         graph_.execute(handle_, ptr_map, workspace.data()));
-#if 0
-    CHECK_CUDNN_ERROR(cudnnConvolutionForward(
-        handle_,
-        &alpha,
-        input_desc_,
-        input,
-        filter_desc_,
-        filter,
-        conv_desc_,
-        algo_,
-        workspace.data(),
-        workspace_size_,
-        &beta,
-        output_desc_,
-        output));
-#endif
   }
 
  private:
@@ -264,7 +147,7 @@ class Convolution {
   std::shared_ptr<graph::Tensor_attributes> input_attr_;
   std::shared_ptr<graph::Tensor_attributes> filter_attr_;
   std::shared_ptr<graph::Tensor_attributes> output_attr_;
-  size_t workspace_size_{0};
+  int64_t workspace_size_{0};
 };
 
 } // namespace cu
@@ -295,6 +178,7 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
       output_strides,
       std::vector<int64_t>(kernel_strides_.begin(), kernel_strides_.end()),
       std::vector<int64_t>(padding_lo_.begin(), padding_lo_.end()),
+      std::vector<int64_t>(padding_hi_.begin(), padding_hi_.end()),
       std::vector<int64_t>(kernel_dilation_.begin(), kernel_dilation_.end()),
       groups_);
   conv.run(encoder, in.data<void>(), wt.data<void>(), out.data<void>());
diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp
index 118c6517c..ff0ff695c 100644
--- a/mlx/backend/cuda/device.cpp
+++ b/mlx/backend/cuda/device.cpp
@@ -9,12 +9,23 @@
 #include <fmt/format.h>
 #include <nvtx3/nvtx3.hpp>
 
-namespace mlx::core {
+namespace mlx::core::cu {
+
+namespace {
 
 // Can be tuned with MLX_MAX_OPS_PER_BUFFER
 // This should be less than 255
 constexpr int default_max_nodes_per_graph = 20;
 
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+void check_cudnn_error(const char* name, cudnnStatus_t err) {
+  if (err != CUDNN_STATUS_SUCCESS) {
+    throw std::runtime_error(
+        fmt::format("{} failed: {}.", name, cudnnGetErrorString(err)));
+  }
+}
+
 int cuda_graph_cache_size() {
   static int cache_size = []() {
     return env::get_var("MLX_CUDA_GRAPH_CACHE_SIZE", 100);
@@ -22,7 +33,7 @@ int cuda_graph_cache_size() {
   return cache_size;
 }
 
-namespace cu {
+} // namespace
 
 Device::Device(int device) : device_(device) {
   CHECK_CUDA_ERROR(cudaDeviceGetAttribute(
@@ -42,11 +53,11 @@ Device::Device(int device) : device_(device) {
   make_current();
   cublasLtCreate(&lt_);
   // The cudnn handle is used by Convolution.
-  cudnnCreate(&cudnn_);
+  CHECK_CUDNN_ERROR(cudnnCreate(&cudnn_));
 }
 
 Device::~Device() {
-  cudnnDestroy(cudnn_);
+  CHECK_CUDNN_ERROR(cudnnDestroy(cudnn_));
   cublasLtDestroy(lt_);
 }
 
@@ -180,6 +191,7 @@ void CommandEncoder::insert_graph_dependencies(std::vector<GraphNode> nodes) {
 
 CommandEncoder::CommandEncoder(Device& d) : device_(d), stream_(d) {
   CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0));
+  CHECK_CUDNN_ERROR(cudnnSetStream(d.cudnn_handle(), stream()));
 }
 
 void clear_graphs(std::unordered_map<std::string, cudaGraphExec_t>& graphs) {
@@ -334,6 +346,4 @@ CommandEncoder& get_command_encoder(Stream s) {
   return device(s.device).get_command_encoder(s);
 }
 
-} // namespace cu
-
-} // namespace mlx::core
+} // namespace mlx::core::cu
diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp
index 32c480d8e..969bc2ba7 100644
--- a/tests/ops_tests.cpp
+++ b/tests/ops_tests.cpp
@@ -3688,7 +3688,17 @@ TEST_CASE("test conv1d") {
 }
 
 TEST_CASE("test conv2d") {
-  array in = zeros({1, 2, 2, 3}, float32);
+  auto in = array(
+      {0.57429284,
+       -0.21628855,
+       -0.18673691,
+       -0.3793517,
+
+       0.3059678,
+       -0.8137168,
+       0.6168841,
+       -0.26912728},
+      {1, 2, 2, 2});
 
   std::pair<int, int> kernel{2, 2};
   std::pair<int, int> stride{1, 1};
@@ -3697,7 +3707,15 @@ TEST_CASE("test conv2d") {
 
   {
     int groups = 1;
-    array wt = ones({1, 2, 2, 3}, float32);
+    auto wt = array(
+        {0.3190391,   -0.24937038, 1.4621079,   -2.0601406,  -0.3224172,
+         -0.38405436, 1.1337694,   -1.0998913,  -0.1724282,  -0.8778584,
+         0.04221375,  0.58281523,  -1.1006192,  1.1447237,   0.9015907,
+         0.50249434,  0.90085596,  -0.68372786, -0.12289023, -0.93576944,
+         -0.26788807, 0.53035545,  -0.69166076, -0.39675352, -0.6871727,
+         -0.84520566, -0.6712461,  -0.0126646,  -1.1173104,  0.2344157,
+         1.6598022,   0.74204415},
+        {4, 2, 2, 2});
 
     auto expected = array(
         {1.9549234, -0.98542136, 0.2097499, 0.20991313}, {1, 1, 1, 4});