From aa7b47481a9188407dd99eaaa456c2b87c621ff0 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 8 Aug 2025 15:23:30 +0900 Subject: [PATCH 01/20] [CUDA] Optimize set_mm_device_pointers for small ndim (#2473) --- .../cuda/gemms/cublas_batched_gemm_12_9.cu | 197 ++++++++++++++---- 1 file changed, 154 insertions(+), 43 deletions(-) diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu index 86733fb06..da7163b42 100644 --- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu +++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu @@ -10,7 +10,34 @@ namespace mlx::core::cu { namespace cg = cooperative_groups; -__global__ void set_mm_device_pointers( +template +__global__ void set_mm_device_pointers_nd( + int8_t** pointers, + int8_t* a_start, + int8_t* b_start, + int8_t* out_start, + int item_size, + const __grid_constant__ cuda::std::array batch_shape, + const __grid_constant__ cuda::std::array a_batch_strides, + const __grid_constant__ cuda::std::array b_batch_strides, + int64_t batch_stride, + int batch_count) { + auto index = cg::this_grid().thread_rank(); + if (index >= batch_count) { + return; + } + auto [a_offset, b_offset] = elem_to_loc_nd( + index, + batch_shape.data(), + a_batch_strides.data(), + b_batch_strides.data()); + pointers[index] = a_start + item_size * a_offset; + pointers[index + batch_count] = b_start + item_size * b_offset; + pointers[index + 2 * batch_count] = + out_start + item_size * index * batch_stride; +} + +__global__ void set_mm_device_pointers_g( int8_t** pointers, int8_t* a_start, int8_t* b_start, @@ -38,7 +65,38 @@ __global__ void set_mm_device_pointers( out_start + item_size * index * batch_stride; } -__global__ void set_addmm_device_pointers( +template +__global__ void set_addmm_device_pointers_nd( + int8_t** pointers, + int8_t* a_start, + int8_t* b_start, + int8_t* c_start, + int8_t* out_start, + int item_size, + const __grid_constant__ cuda::std::array batch_shape, + const __grid_constant__ cuda::std::array a_batch_strides, + const __grid_constant__ cuda::std::array b_batch_strides, + const __grid_constant__ cuda::std::array c_batch_strides, + int64_t batch_stride, + int batch_count) { + auto index = cg::this_grid().thread_rank(); + if (index >= batch_count) { + return; + } + auto [a_offset, b_offset, c_offset] = elem_to_loc_nd( + index, + batch_shape.data(), + a_batch_strides.data(), + b_batch_strides.data(), + c_batch_strides.data()); + pointers[index] = a_start + item_size * a_offset; + pointers[index + batch_count] = b_start + item_size * b_offset; + pointers[index + 2 * batch_count] = c_start + item_size * c_offset; + pointers[index + 3 * batch_count] = + out_start + item_size * index * batch_stride; +} + +__global__ void set_addmm_device_pointers_g( int8_t** pointers, int8_t* a_start, int8_t* b_start, @@ -89,37 +147,62 @@ void Matmul::run_batched( const mlx::core::Shape& batch_shape, const mlx::core::Strides& a_batch_strides, const mlx::core::Strides& b_batch_strides) { - auto batch_count = out.size() / (M_ * N_); + int batch_count = out.size() / (M_ * N_); set_pointer_mode(a_desc_, batch_count); set_pointer_mode(b_desc_, batch_count); set_pointer_mode(out_desc_, batch_count); // Launch kernel to set device offsets auto pointers = array( - allocator::malloc(batch_count * sizeof(uint64_t) * 3), - {static_cast(batch_count * 3)}, + allocator::malloc(batch_count * sizeof(void*) * 3), + {batch_count * 3}, uint64); encoder.add_temporary(pointers); - int block_size = 512; 
encoder.set_output_array(pointers); - encoder.add_kernel_node( - cu::set_mm_device_pointers, - cuda::ceil_div(pointers.size(), block_size), - block_size, - 0, - pointers.data(), - a.data(), - b.data(), - out.data(), - static_cast(out.dtype().size()), - const_param(batch_shape), - const_param(a_batch_strides), - const_param(b_batch_strides), - static_cast(M_) * N_, - static_cast(batch_shape.size()), - batch_count); + int block_dims = std::min(batch_count, 256); + int num_blocks = cuda::ceil_div(batch_count, block_dims); + int64_t batch_stride = M_ * N_; + int item_size = out.itemsize(); + + int ndim = batch_shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + encoder.add_kernel_node( + cu::set_mm_device_pointers_nd, + num_blocks, + block_dims, + 0, + pointers.data(), + a.data(), + b.data(), + out.data(), + item_size, + const_param(batch_shape), + const_param(a_batch_strides), + const_param(b_batch_strides), + batch_stride, + batch_count); + }); + } else { + encoder.add_kernel_node( + cu::set_mm_device_pointers_g, + num_blocks, + block_dims, + 0, + pointers.data(), + a.data(), + b.data(), + out.data(), + item_size, + const_param(batch_shape), + const_param(a_batch_strides), + const_param(b_batch_strides), + batch_stride, + ndim, + batch_count); + } // Run matmul encoder.set_input_array(pointers); @@ -150,7 +233,7 @@ void Matmul::run_batched( const mlx::core::Strides& c_batch_strides, float alpha, float beta) { - auto batch_count = out.size() / (M_ * N_); + int batch_count = out.size() / (M_ * N_); set_pointer_mode(a_desc_, batch_count); set_pointer_mode(b_desc_, batch_count); set_pointer_mode(c_desc_, batch_count); @@ -159,30 +242,58 @@ void Matmul::run_batched( // Launch kernel to set device offsets auto pointers = array( allocator::malloc(batch_count * sizeof(uint64_t) * 4), - {static_cast(batch_count * 4)}, + {batch_count * 4}, uint64); encoder.add_temporary(pointers); - int block_size = 512; encoder.set_output_array(pointers); - encoder.add_kernel_node( - cu::set_addmm_device_pointers, - cuda::ceil_div(pointers.size(), block_size), - block_size, - 0, - pointers.data(), - a.data(), - b.data(), - c.data(), - out.data(), - static_cast(out.dtype().size()), - const_param(batch_shape), - const_param(a_batch_strides), - const_param(b_batch_strides), - const_param(c_batch_strides), - static_cast(M_) * N_, - static_cast(batch_shape.size()), - batch_count); + + int block_dims = std::min(batch_count, 256); + int num_blocks = cuda::ceil_div(batch_count, block_dims); + int64_t batch_stride = M_ * N_; + int item_size = out.itemsize(); + + int ndim = batch_shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + encoder.add_kernel_node( + cu::set_addmm_device_pointers_nd, + num_blocks, + block_dims, + 0, + pointers.data(), + a.data(), + b.data(), + c.data(), + out.data(), + item_size, + const_param(batch_shape), + const_param(a_batch_strides), + const_param(b_batch_strides), + const_param(c_batch_strides), + batch_stride, + batch_count); + }); + } else { + encoder.add_kernel_node( + cu::set_addmm_device_pointers_g, + num_blocks, + block_dims, + 0, + pointers.data(), + a.data(), + b.data(), + c.data(), + out.data(), + item_size, + const_param(batch_shape), + const_param(a_batch_strides), + const_param(b_batch_strides), + const_param(c_batch_strides), + batch_stride, + ndim, + batch_count); + } // Run matmul encoder.set_input_array(pointers); From 7fde1b6a1e9af278d9e219ec0452229e768e4792 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 9 Aug 2025 
06:07:17 +0900 Subject: [PATCH 02/20] Fix logsumexp/softmax not fused for some cases (#2474) --- mlx/ops.cpp | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 6c4f76424..14e135deb 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2381,9 +2381,20 @@ array logsumexp( throw std::invalid_argument( "[logsumexp] Received non-empty axes for array with 0 dimensions."); } + bool reduce_last_dim = + !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1); + if (reduce_last_dim) { + // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape + // is [1, 1, ..., N]. + for (int i = axes.size() - 2; i >= 0; --i) { + if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) { + reduce_last_dim = false; + break; + } + } + } bool is_complex = issubdtype(a.dtype(), complexfloating); - if (!is_complex && axes.size() == 1 && - (a.ndim() == axes[0] + 1 || axes[0] == -1)) { + if (!is_complex && reduce_last_dim) { auto dtype = at_least_float(a.dtype()); auto out_shape = a.shape(); out_shape.back() = 1; @@ -3403,10 +3414,20 @@ array softmax( throw std::invalid_argument( "[softmax] Received non-empty axes for array with 0 dimensions."); } - + bool reduce_last_dim = + !axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1); + if (reduce_last_dim) { + // For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape + // is [1, 1, ..., N]. + for (int i = axes.size() - 2; i >= 0; --i) { + if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) { + reduce_last_dim = false; + break; + } + } + } bool is_complex = issubdtype(a.dtype(), complexfloating); - if (!is_complex && axes.size() == 1 && - (a.ndim() == axes[0] + 1 || axes[0] == -1)) { + if (!is_complex && reduce_last_dim) { auto dtype = at_least_float(a.dtype()); return array( a.shape(), From 8ae4a763085ac90b6e27f01faa32df06dab0305a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 12 Aug 2025 00:03:42 -0700 Subject: [PATCH 03/20] Use CMake <4.1 to avoid the nvpl error (#2489) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6fcd5d16c..4fc9675df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,6 @@ requires = [ "setuptools>=80", "nanobind==2.4.0", - "cmake>=3.25", + "cmake>=3.25,<4.1", ] build-backend = "setuptools.build_meta" From fce53b61d6a93f3a86e74b0a8a3bc86547228c11 Mon Sep 17 00:00:00 2001 From: Abe Leininger <95333017+abeleinin@users.noreply.github.com> Date: Tue, 12 Aug 2025 02:05:33 -0500 Subject: [PATCH 04/20] Fix reduce sum/prod overflow (#2477) --- mlx/backend/cpu/reduce.cpp | 14 ++++++++--- mlx/backend/metal/kernels/reduce.metal | 4 ++++ mlx/backend/metal/reduce.cpp | 20 ++++++++++++---- tests/gpu_tests.cpp | 13 +++++++++++ tests/ops_tests.cpp | 32 ++++++++++++++++++++++++++ 5 files changed, 75 insertions(+), 8 deletions(-) diff --git a/mlx/backend/cpu/reduce.cpp b/mlx/backend/cpu/reduce.cpp index 8febbd050..41764f4c8 100644 --- a/mlx/backend/cpu/reduce.cpp +++ b/mlx/backend/cpu/reduce.cpp @@ -491,19 +491,27 @@ void Reduce::eval_cpu(const std::vector& inputs, array& out) { switch (in.dtype()) { case bool_: case uint8: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case uint16: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case uint32: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; + case uint64: + reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); + break; 
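// The unsigned cases above are new in this patch (#2477): previously each
// unsigned type fell through to the dispatch of its same-width signed
// counterpart, so sums and products of unsigned inputs could overflow.
// The new ops_tests.cpp case below checks, for example, that 1000 uint8
// values of 255 sum to 255000. The widened accumulators mirror the Metal
// instantiations added later in this patch, e.g. uint16 accumulating in a
// 32-bit type while uint64 keeps a 64-bit accumulator.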
case int8: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int16: - case uint16: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int32: - case uint32: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case int64: - case uint64: reduce_dispatch_sum_prod(in, out, reduce_type_, axes_); break; case float16: diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 428f65012..de5dfbad7 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -134,6 +134,10 @@ instantiate_and_or(and, And) instantiate_and_or(or, Or) #define instantiate_sum_prod(name, op) \ + instantiate_reduce_functions(name, uint8, uint8_t, int32_t, op) \ + instantiate_reduce_functions(name, uint16, uint16_t, uint32_t, op) \ + instantiate_reduce_functions(name, uint32, uint32_t, uint32_t, op) \ + instantiate_reduce_functions(name, uint64, uint64_t, uint64_t, op) \ instantiate_reduce_functions(name, int8, int8_t, int32_t, op) \ instantiate_reduce_functions(name, int16, int16_t, int32_t, op) \ instantiate_reduce_functions(name, int32, int32_t, int32_t, op) \ diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 3ae766ba9..504943d82 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -247,15 +247,25 @@ std::pair remap_reduce_types( const std::string& op_name) { if (op_name == "sum" || op_name == "prod") { if (issubdtype(in.dtype(), integer)) { - switch (in.dtype().size()) { - case 1: + switch (in.dtype()) { + case uint8: + return {uint8, uint32}; + case uint16: + return {uint16, uint32}; + case uint32: + return {uint32, uint32}; + case uint64: + return {uint64, uint64}; + case int8: return {int8, int32}; - case 2: + case int16: return {int16, int32}; - case 4: + case int32: return {int32, int32}; - case 8: + case int64: return {int64, int64}; + default: + throw std::runtime_error("Unsupported integer type"); } } if (in.dtype() == bool_) { diff --git a/tests/gpu_tests.cpp b/tests/gpu_tests.cpp index f0ef969cf..625cbf552 100644 --- a/tests/gpu_tests.cpp +++ b/tests/gpu_tests.cpp @@ -155,6 +155,19 @@ TEST_CASE("test gpu reduce") { CHECK_EQ(prod(a, Device::gpu).item(), 1); } + // sum and prod overflow + { + auto a = full({256, 2, 2}, 1u, uint8); + CHECK_EQ(sum(a, Device::gpu).item(), 256 * 4); + CHECK_EQ(prod(a, Device::gpu).item(), 1); + + a = full({65535, 2, 2}, 1u, uint16); + CHECK_EQ(sum(a, Device::gpu).item(), 65535 * 4); + CHECK_EQ(prod(a, Device::gpu).item(), 1); + } +} + +TEST_CASE("test gpu reduce with axes") { // reducing only some axes and irregular layouts { array a(1.0f); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 969bc2ba7..17207efd4 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -915,6 +915,23 @@ TEST_CASE("test reduction ops") { CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item()); } + // Test unsigned sum + { + const int num_elems = 1000; + + auto x = astype(full({num_elems}, 255), uint8); + CHECK_EQ(sum(x, Device::cpu).item(), 255 * num_elems); + + x = astype(full({num_elems}, 65535), uint16); + CHECK_EQ(sum(x, Device::cpu).item(), 65535 * num_elems); + + x = full({3, 3, 3}, 10000, uint32); + CHECK_EQ(sum(x, Device::cpu).item(), 270000); + + x = full({3, 3, 3}, 10000, uint64); + CHECK_EQ(sum(x, Device::cpu).item(), 270000); + } + // Test prod { auto x = array({}); @@ -947,6 +964,21 @@ TEST_CASE("test reduction ops") { CHECK(array_equal(prod(x, 1), array({true, false})).item()); } + // Test 
unsigned prod + { + auto x = array({255, 255}, {2}, uint8); + CHECK_EQ(prod(x, Device::cpu).item(), 65025); + + x = array({65535, 2}, {2}, uint16); + CHECK_EQ(prod(x, Device::cpu).item(), 131070); + + x = array({100000, 2}, {2}, uint32); + CHECK_EQ(prod(x, Device::cpu).item(), 200000); + + x = array({100000, 2}, {2}, uint64); + CHECK_EQ(prod(x, Device::cpu).item(), 200000); + } + // Test all { auto x = array({}); From ac207ce7aaa73b8be3909f9fe575f60870ff179e Mon Sep 17 00:00:00 2001 From: Daniel Yeh <46629671+Dan-Yeh@users.noreply.github.com> Date: Tue, 12 Aug 2025 21:29:02 +0200 Subject: [PATCH 05/20] make code blocks copyable (#2480) Co-authored-by: Chen-Chen Yeh --- docs/requirements.txt | 1 + docs/src/conf.py | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 9c8ff52f4..06404fa75 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,5 @@ sphinx breathe sphinx-book-theme +sphinx-copybutton mlx diff --git a/docs/src/conf.py b/docs/src/conf.py index d9dd32ad1..446d95cd2 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -18,6 +18,7 @@ release = version # -- General configuration --------------------------------------------------- extensions = [ + "sphinx_copybutton", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.intersphinx", From dfb5022eab342ebe46683aaf971504fc42fcd862 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 13 Aug 2025 09:37:40 +0900 Subject: [PATCH 06/20] Rename cu::Matmul to CublasGemm (#2488) --- mlx/backend/cuda/CMakeLists.txt | 4 +- mlx/backend/cuda/gemms/cublas_gemm.cpp | 121 ++++++++++++------ mlx/backend/cuda/gemms/cublas_gemm.h | 55 +++++--- ..._12_0.cpp => cublas_gemm_batched_12_0.cpp} | 26 ++-- ...mm_12_9.cu => cublas_gemm_batched_12_9.cu} | 34 +++-- mlx/backend/cuda/matmul.cpp | 20 +-- 6 files changed, 157 insertions(+), 103 deletions(-) rename mlx/backend/cuda/gemms/{cublas_batched_gemm_12_0.cpp => cublas_gemm_batched_12_0.cpp} (80%) rename mlx/backend/cuda/gemms/{cublas_batched_gemm_12_9.cu => cublas_gemm_batched_12_9.cu} (95%) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 59db14535..084e449b0 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -53,10 +53,10 @@ target_sources( if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) target_sources( - mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu) + mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu) else() target_sources( - mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp) + mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp) endif() target_compile_definitions(mlx PRIVATE MLX_USE_CUDA) diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp index 61f12ba1d..1aeeefa38 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp @@ -7,10 +7,12 @@ #include -namespace mlx::core::cu { +namespace mlx::core { + +namespace { struct CublasPreference { - CublasPreference(Device& device) { + CublasPreference(cu::Device& device) { // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB // for Hopper+: // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace @@ -33,7 +35,7 @@ struct CublasPreference { cublasLtMatmulPreference_t pref_{nullptr}; }; -cublasLtMatmulPreference_t cublas_preference(Device& device) { +cublasLtMatmulPreference_t cublas_preference(cu::Device& 
device) { static CublasPreference pref(device); return pref.pref_; } @@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) { return CUBLAS_COMPUTE_64F; default: throw std::runtime_error(fmt::format( - "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); + "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype))); } } @@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) { return CUDA_C_32F; default: throw std::runtime_error(fmt::format( - "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype))); + "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype))); } } @@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout( return desc; } -Matmul::Matmul( - Device& device, +} // namespace + +CublasGemm::CublasGemm( + cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, @@ -155,8 +159,8 @@ Matmul::Matmul( type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); } -Matmul::Matmul( - Device& device, +CublasGemm::CublasGemm( + cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, @@ -171,7 +175,7 @@ Matmul::Matmul( int64_t a_batch_stride, int64_t b_batch_stride, int64_t c_batch_stride) - : Matmul( + : CublasGemm( device, dtype, a_transposed, @@ -190,7 +194,7 @@ Matmul::Matmul( type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride); } -Matmul::~Matmul() { +CublasGemm::~CublasGemm() { CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_)); @@ -198,7 +202,73 @@ Matmul::~Matmul() { CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_)); } -void Matmul::run_impl( +void CublasGemm::run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides) { + int batch_count = out.size() / (M_ * N_); + if (batch_count / batch_shape.back() > 1) { + run_batched( + encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); + return; + } + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + execute(encoder, out.data(), a.data(), b.data(), nullptr); +} + +void CublasGemm::run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& c, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, + float alpha, + float beta) { + int batch_count = out.size() / (M_ * N_); + if (batch_count / batch_shape.back() > 1) { + run_batched( + encoder, + out, + a, + b, + c, + batch_shape, + a_batch_strides, + b_batch_strides, + c_batch_strides, + alpha, + beta); + return; + } + + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_input_array(c); + encoder.set_output_array(out); + + execute( + encoder, + out.data(), + a.data(), + b.data(), + c.data(), + alpha, + beta); +} + +void CublasGemm::execute( cu::CommandEncoder& encoder, void* out, const void* a, @@ -256,29 +326,4 @@ void Matmul::run_impl( encoder.stream())); } -void Matmul::run( - cu::CommandEncoder& encoder, - array& out, - const array& a, - const array& b, - const std::optional& c /* = std::nullopt */, - float alpha /* = 1 */, - float beta /* = 0 */) { - encoder.set_input_array(a); - encoder.set_input_array(b); - if (c) { - encoder.set_input_array(*c); - } - encoder.set_output_array(out); - - run_impl( - encoder, - out.data(), - a.data(), - 
b.data(), - c ? c->data() : nullptr, - alpha, - beta); -} - -} // namespace mlx::core::cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h index eccee8580..e093351b6 100644 --- a/mlx/backend/cuda/gemms/cublas_gemm.h +++ b/mlx/backend/cuda/gemms/cublas_gemm.h @@ -5,13 +5,13 @@ #include "mlx/backend/cuda/device.h" #include -#include -namespace mlx::core::cu { -class Matmul { +namespace mlx::core { + +class CublasGemm { public: - Matmul( - Device& device, + CublasGemm( + cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, @@ -25,8 +25,8 @@ class Matmul { int64_t a_batch_stride, int64_t b_batch_stride); - Matmul( - Device& device, + CublasGemm( + cu::Device& device, Dtype dtype, bool a_transposed, uint64_t a_rows, @@ -42,25 +42,39 @@ class Matmul { int64_t b_batch_stride, int64_t c_batch_stride); - ~Matmul(); + ~CublasGemm(); void run( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, - const std::optional& c = std::nullopt, - float alpha = 1, - float beta = 0); + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides); + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& c, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, + float alpha, + float beta); + + private: void run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, - const mlx::core::Shape& batch_shape, - const mlx::core::Strides& a_batch_strides, - const mlx::core::Strides& b_batch_strides); + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides); void run_batched( cu::CommandEncoder& encoder, @@ -68,15 +82,14 @@ class Matmul { const array& a, const array& b, const array& c, - const mlx::core::Shape& batch_shape, - const mlx::core::Strides& a_batch_strides, - const mlx::core::Strides& b_batch_strides, - const mlx::core::Strides& c_batch_strides, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, float alpha, float beta); - private: - void run_impl( + void execute( cu::CommandEncoder& encoder, void* out, const void* a, @@ -97,4 +110,4 @@ class Matmul { cublasLtMatmulHeuristicResult_t heuristic_; }; -} // namespace mlx::core::cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp similarity index 80% rename from mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp rename to mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp index 39a8a5ddd..56c731587 100644 --- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp +++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp @@ -4,16 +4,16 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/gemms/cublas_gemm.h" -namespace mlx::core::cu { +namespace mlx::core { -void Matmul::run_batched( +void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, - const mlx::core::Shape& batch_shape, - const mlx::core::Strides& a_batch_strides, - const mlx::core::Strides& b_batch_strides) { + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides) { encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_output_array(out); @@ -22,7 +22,7 @@ void Matmul::run_batched( ContiguousIterator 
b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { - run_impl( + execute( encoder, out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_, a.data() + a.itemsize() * a_it.loc, @@ -33,16 +33,16 @@ void Matmul::run_batched( } } -void Matmul::run_batched( +void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& c, - const mlx::core::Shape& batch_shape, - const mlx::core::Strides& a_batch_strides, - const mlx::core::Strides& b_batch_strides, - const mlx::core::Strides& c_batch_strides, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, float alpha, float beta) { encoder.set_input_array(a); @@ -56,7 +56,7 @@ void Matmul::run_batched( ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1); auto concurrent = encoder.concurrent_context(); for (size_t i = 0; i < nbatch; ++i) { - run_impl( + execute( encoder, out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_, a.data() + a.itemsize() * a_it.loc, @@ -70,4 +70,4 @@ void Matmul::run_batched( } } -} // namespace mlx::core::cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu similarity index 95% rename from mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu rename to mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu index da7163b42..570b79463 100644 --- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu +++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu @@ -6,7 +6,9 @@ #include -namespace mlx::core::cu { +namespace mlx::core { + +namespace cu { namespace cg = cooperative_groups; @@ -128,6 +130,10 @@ __global__ void set_addmm_device_pointers_g( out_start + item_size * index * batch_stride; } +} // namespace cu + +namespace { + void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) { auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY; CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( @@ -139,14 +145,16 @@ void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) { desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t))); } -void Matmul::run_batched( +} // namespace + +void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, - const mlx::core::Shape& batch_shape, - const mlx::core::Strides& a_batch_strides, - const mlx::core::Strides& b_batch_strides) { + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides) { int batch_count = out.size() / (M_ * N_); set_pointer_mode(a_desc_, batch_count); set_pointer_mode(b_desc_, batch_count); @@ -213,7 +221,7 @@ void Matmul::run_batched( auto a_pointers = pointers.data(); auto b_pointers = a_pointers + batch_count; auto out_pointers = b_pointers + batch_count; - run_impl( + execute( encoder, reinterpret_cast(out_pointers), reinterpret_cast(a_pointers), @@ -221,16 +229,16 @@ void Matmul::run_batched( nullptr); } -void Matmul::run_batched( +void CublasGemm::run_batched( cu::CommandEncoder& encoder, array& out, const array& a, const array& b, const array& c, - const mlx::core::Shape& batch_shape, - const mlx::core::Strides& a_batch_strides, - const mlx::core::Strides& b_batch_strides, - const mlx::core::Strides& c_batch_strides, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, 
+ const Strides& c_batch_strides, float alpha, float beta) { int batch_count = out.size() / (M_ * N_); @@ -306,7 +314,7 @@ void Matmul::run_batched( auto b_pointers = a_pointers + batch_count; auto c_pointers = b_pointers + batch_count; auto out_pointers = c_pointers + batch_count; - run_impl( + execute( encoder, reinterpret_cast(out_pointers), reinterpret_cast(a_pointers), @@ -316,4 +324,4 @@ void Matmul::run_batched( beta); } -} // namespace mlx::core::cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index 283aaaf2e..b11fae538 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -97,7 +97,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Invoke cublasLt - cu::Matmul matmul( + CublasGemm gemm( cu::device(s.device), a.dtype(), a_transposed, @@ -111,14 +111,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { batch_shape.back(), a_batch_strides.back(), b_batch_strides.back()); - - if ((batch_count / batch_shape.back()) == 1) { - matmul.run(encoder, out, a, b); - return; - } - - matmul.run_batched( - encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); + gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides); } void AddMM::eval_gpu(const std::vector& inputs, array& out) { @@ -186,7 +179,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { ///////////////////////////////////////////////////////////////////////////// // Invoke cublasLt - cu::Matmul matmul( + CublasGemm gemm( cu::device(s.device), a.dtype(), a_transposed, @@ -202,12 +195,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { a_batch_strides.back(), b_batch_strides.back(), c_batch_strides.back()); - - if ((batch_count / batch_shape.back()) == 1) { - matmul.run(encoder, out, a, b, c, alpha_, beta_); - return; - } - matmul.run_batched( + gemm.run( encoder, out, a, From 6441c21a944b710167f9ad5bee35c02ed8b8599a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 15 Aug 2025 15:04:12 -0700 Subject: [PATCH 07/20] Faster general unary op (#2472) * faster general unary op * faster general ops + reorg * fix + comment * binary two * copy general --- mlx/backend/cuda/CMakeLists.txt | 5 +- mlx/backend/cuda/binary/CMakeLists.txt | 21 ++ mlx/backend/cuda/binary/add.cu | 7 + mlx/backend/cuda/binary/arctan2.cu | 7 + .../cuda/{binary.cu => binary/binary.cuh} | 178 ++++++++------- mlx/backend/cuda/binary/bitwise_binary.cu | 27 +++ mlx/backend/cuda/binary/divide.cu | 7 + mlx/backend/cuda/binary/equal.cu | 15 ++ mlx/backend/cuda/binary/greater.cu | 7 + mlx/backend/cuda/binary/greater_equal.cu | 7 + mlx/backend/cuda/binary/less.cu | 7 + mlx/backend/cuda/binary/less_equal.cu | 7 + mlx/backend/cuda/binary/log_add_exp.cu | 7 + mlx/backend/cuda/binary/logical_and.cu | 7 + mlx/backend/cuda/binary/logical_or.cu | 7 + mlx/backend/cuda/binary/maximum.cu | 7 + mlx/backend/cuda/binary/minimum.cu | 7 + mlx/backend/cuda/binary/multiply.cu | 7 + mlx/backend/cuda/binary/not_equal.cu | 7 + mlx/backend/cuda/binary/power.cu | 7 + mlx/backend/cuda/binary/remainder.cu | 7 + mlx/backend/cuda/binary/subtract.cu | 7 + mlx/backend/cuda/binary_two.cu | 142 +++++++++--- mlx/backend/cuda/copy/copy_general.cu | 110 +++++++-- mlx/backend/cuda/copy/copy_general_input.cu | 97 ++++++-- mlx/backend/cuda/device/utils.cuh | 17 ++ mlx/backend/cuda/ternary.cu | 127 ++++++++--- mlx/backend/cuda/unary.cu | 54 +++-- 
mlx/backend/cuda/unary/CMakeLists.txt | 34 +++ mlx/backend/cuda/unary/abs.cu | 7 + mlx/backend/cuda/unary/arccos.cu | 7 + mlx/backend/cuda/unary/arccosh.cu | 7 + mlx/backend/cuda/unary/arcsin.cu | 7 + mlx/backend/cuda/unary/arcsinh.cu | 7 + mlx/backend/cuda/unary/arctan.cu | 7 + mlx/backend/cuda/unary/arctanh.cu | 7 + mlx/backend/cuda/unary/bitwise_invert.cu | 7 + mlx/backend/cuda/unary/ceil.cu | 7 + mlx/backend/cuda/unary/conjugate.cu | 7 + mlx/backend/cuda/unary/cos.cu | 7 + mlx/backend/cuda/unary/cosh.cu | 7 + mlx/backend/cuda/unary/erf.cu | 7 + mlx/backend/cuda/unary/erf_inv.cu | 7 + mlx/backend/cuda/unary/exp.cu | 7 + mlx/backend/cuda/unary/expm1.cu | 7 + mlx/backend/cuda/unary/floor.cu | 7 + mlx/backend/cuda/unary/imag.cu | 7 + mlx/backend/cuda/unary/log.cu | 21 ++ mlx/backend/cuda/unary/log1p.cu | 7 + mlx/backend/cuda/unary/logical_not.cu | 7 + mlx/backend/cuda/unary/negative.cu | 7 + mlx/backend/cuda/unary/real.cu | 7 + mlx/backend/cuda/unary/round.cu | 18 ++ mlx/backend/cuda/unary/sigmoid.cu | 7 + mlx/backend/cuda/unary/sign.cu | 7 + mlx/backend/cuda/unary/sin.cu | 7 + mlx/backend/cuda/unary/sinh.cu | 7 + mlx/backend/cuda/unary/sqrt.cu | 15 ++ mlx/backend/cuda/unary/square.cu | 7 + mlx/backend/cuda/unary/tan.cu | 7 + mlx/backend/cuda/unary/tanh.cu | 7 + mlx/backend/cuda/unary/unary.cuh | 215 ++++++++++++++++++ 62 files changed, 1215 insertions(+), 203 deletions(-) create mode 100644 mlx/backend/cuda/binary/CMakeLists.txt create mode 100644 mlx/backend/cuda/binary/add.cu create mode 100644 mlx/backend/cuda/binary/arctan2.cu rename mlx/backend/cuda/{binary.cu => binary/binary.cuh} (72%) create mode 100644 mlx/backend/cuda/binary/bitwise_binary.cu create mode 100644 mlx/backend/cuda/binary/divide.cu create mode 100644 mlx/backend/cuda/binary/equal.cu create mode 100644 mlx/backend/cuda/binary/greater.cu create mode 100644 mlx/backend/cuda/binary/greater_equal.cu create mode 100644 mlx/backend/cuda/binary/less.cu create mode 100644 mlx/backend/cuda/binary/less_equal.cu create mode 100644 mlx/backend/cuda/binary/log_add_exp.cu create mode 100644 mlx/backend/cuda/binary/logical_and.cu create mode 100644 mlx/backend/cuda/binary/logical_or.cu create mode 100644 mlx/backend/cuda/binary/maximum.cu create mode 100644 mlx/backend/cuda/binary/minimum.cu create mode 100644 mlx/backend/cuda/binary/multiply.cu create mode 100644 mlx/backend/cuda/binary/not_equal.cu create mode 100644 mlx/backend/cuda/binary/power.cu create mode 100644 mlx/backend/cuda/binary/remainder.cu create mode 100644 mlx/backend/cuda/binary/subtract.cu create mode 100644 mlx/backend/cuda/unary/CMakeLists.txt create mode 100644 mlx/backend/cuda/unary/abs.cu create mode 100644 mlx/backend/cuda/unary/arccos.cu create mode 100644 mlx/backend/cuda/unary/arccosh.cu create mode 100644 mlx/backend/cuda/unary/arcsin.cu create mode 100644 mlx/backend/cuda/unary/arcsinh.cu create mode 100644 mlx/backend/cuda/unary/arctan.cu create mode 100644 mlx/backend/cuda/unary/arctanh.cu create mode 100644 mlx/backend/cuda/unary/bitwise_invert.cu create mode 100644 mlx/backend/cuda/unary/ceil.cu create mode 100644 mlx/backend/cuda/unary/conjugate.cu create mode 100644 mlx/backend/cuda/unary/cos.cu create mode 100644 mlx/backend/cuda/unary/cosh.cu create mode 100644 mlx/backend/cuda/unary/erf.cu create mode 100644 mlx/backend/cuda/unary/erf_inv.cu create mode 100644 mlx/backend/cuda/unary/exp.cu create mode 100644 mlx/backend/cuda/unary/expm1.cu create mode 100644 mlx/backend/cuda/unary/floor.cu create mode 100644 mlx/backend/cuda/unary/imag.cu 
create mode 100644 mlx/backend/cuda/unary/log.cu create mode 100644 mlx/backend/cuda/unary/log1p.cu create mode 100644 mlx/backend/cuda/unary/logical_not.cu create mode 100644 mlx/backend/cuda/unary/negative.cu create mode 100644 mlx/backend/cuda/unary/real.cu create mode 100644 mlx/backend/cuda/unary/round.cu create mode 100644 mlx/backend/cuda/unary/sigmoid.cu create mode 100644 mlx/backend/cuda/unary/sign.cu create mode 100644 mlx/backend/cuda/unary/sin.cu create mode 100644 mlx/backend/cuda/unary/sinh.cu create mode 100644 mlx/backend/cuda/unary/sqrt.cu create mode 100644 mlx/backend/cuda/unary/square.cu create mode 100644 mlx/backend/cuda/unary/tan.cu create mode 100644 mlx/backend/cuda/unary/tanh.cu create mode 100644 mlx/backend/cuda/unary/unary.cuh diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 084e449b0..0d526400d 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -8,7 +8,6 @@ target_sources( PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu - ${CMAKE_CURRENT_SOURCE_DIR}/binary.cu ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu @@ -45,12 +44,14 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu - ${CMAKE_CURRENT_SOURCE_DIR}/unary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary) + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0) target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu) diff --git a/mlx/backend/cuda/binary/CMakeLists.txt b/mlx/backend/cuda/binary/CMakeLists.txt new file mode 100644 index 000000000..bda289de7 --- /dev/null +++ b/mlx/backend/cuda/binary/CMakeLists.txt @@ -0,0 +1,21 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu) diff --git a/mlx/backend/cuda/binary/add.cu b/mlx/backend/cuda/binary/add.cu new file mode 100644 index 000000000..87dfd7e70 --- /dev/null +++ b/mlx/backend/cuda/binary/add.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Add) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/arctan2.cu b/mlx/backend/cuda/binary/arctan2.cu new file mode 100644 index 000000000..2fd7e3922 --- /dev/null +++ b/mlx/backend/cuda/binary/arctan2.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(ArcTan2) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary/binary.cuh similarity index 72% rename from mlx/backend/cuda/binary.cu rename to mlx/backend/cuda/binary/binary.cuh index 0243d4f41..20bb199ec 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary/binary.cuh @@ -99,39 +99,89 @@ __global__ void binary_vv(const In* a, const In* b, Out* out, IdxT size) { } } -template +template < + typename Op, + typename In, + typename Out, + typename IdxT, + int NDIM, + int N_READS> __global__ void binary_g_nd( const In* a, const In* b, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array a_strides, const __grid_constant__ cuda::std::array b_strides) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc_nd( - index, shape.data(), a_strides.data(), b_strides.data()); - out[index] = Op{}(a[a_idx], b[b_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto a_stride_x = a_strides[NDIM - 1]; + auto b_stride_x = b_strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc_nd( + index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data()); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } -template +template __global__ void binary_g( const In* a, const In* b, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides a_strides, const __grid_constant__ Strides b_strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc( - index, shape.data(), a_strides.data(), b_strides.data(), ndim); - out[index] = Op{}(a[a_idx], b[b_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto a_stride_x = a_strides[ndim - 1]; + auto b_stride_x = b_strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + ndim); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; 
++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } template @@ -209,39 +259,61 @@ void binary_op_gpu_inplace( auto& a_strides = strides[0]; auto& b_strides = strides[1]; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = - get_launch_args(out, large()); + auto kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 1>; + if (work_per_thread == 4) { + kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 4>; + } encoder.add_kernel_node( - cu::binary_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out.data(), - out.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides)); }); } else { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = cu::binary_g; + if (work_per_thread == 4) { + kernel = cu::binary_g; + } encoder.add_kernel_node( - cu::binary_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out.data(), - out.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), @@ -304,54 +376,4 @@ void binary_op_gpu( binary_op_gpu(inputs, out, name(), s); \ } -BINARY_GPU(Add) -BINARY_GPU(ArcTan2) -BINARY_GPU(Divide) -BINARY_GPU(Remainder) -BINARY_GPU(Greater) -BINARY_GPU(GreaterEqual) -BINARY_GPU(Less) -BINARY_GPU(LessEqual) -BINARY_GPU(LogicalAnd) -BINARY_GPU(LogicalOr) -BINARY_GPU(LogAddExp) -BINARY_GPU(Maximum) -BINARY_GPU(Minimum) -BINARY_GPU(Multiply) -BINARY_GPU(NotEqual) -BINARY_GPU(Power) -BINARY_GPU(Subtract) - -void Equal::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("Equal::eval_gpu"); - auto& s = out.primitive().stream(); - if (equal_nan_) { - binary_op_gpu(inputs, out, name(), s); - } else { - binary_op_gpu(inputs, out, name(), s); - } -} - -void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); - auto& s = out.primitive().stream(); - switch (op_) { - case BitwiseBinary::And: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::Or: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::Xor: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::LeftShift: - binary_op_gpu(inputs, out, name(), s); - break; - case BitwiseBinary::RightShift: - binary_op_gpu(inputs, out, name(), s); - break; - } -} - } // namespace mlx::core diff --git a/mlx/backend/cuda/binary/bitwise_binary.cu b/mlx/backend/cuda/binary/bitwise_binary.cu new file mode 100644 index 000000000..8025a3bd1 --- /dev/null +++ b/mlx/backend/cuda/binary/bitwise_binary.cu @@ -0,0 +1,27 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +void BitwiseBinary::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("BitwiseBinary::eval_gpu"); + auto& s = out.primitive().stream(); + switch (op_) { + case BitwiseBinary::And: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Or: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::Xor: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::LeftShift: + binary_op_gpu(inputs, out, name(), s); + break; + case BitwiseBinary::RightShift: + binary_op_gpu(inputs, out, name(), s); + break; + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/divide.cu b/mlx/backend/cuda/binary/divide.cu new file mode 100644 index 000000000..fcf3dc77e --- /dev/null +++ b/mlx/backend/cuda/binary/divide.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Divide) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/equal.cu b/mlx/backend/cuda/binary/equal.cu new file mode 100644 index 000000000..559b3e8ed --- /dev/null +++ b/mlx/backend/cuda/binary/equal.cu @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +void Equal::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Equal::eval_gpu"); + auto& s = out.primitive().stream(); + if (equal_nan_) { + binary_op_gpu(inputs, out, name(), s); + } else { + binary_op_gpu(inputs, out, name(), s); + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/greater.cu b/mlx/backend/cuda/binary/greater.cu new file mode 100644 index 000000000..c9820206b --- /dev/null +++ b/mlx/backend/cuda/binary/greater.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Greater) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/greater_equal.cu b/mlx/backend/cuda/binary/greater_equal.cu new file mode 100644 index 000000000..4666fb4a9 --- /dev/null +++ b/mlx/backend/cuda/binary/greater_equal.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(GreaterEqual) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/less.cu b/mlx/backend/cuda/binary/less.cu new file mode 100644 index 000000000..a2053fa8b --- /dev/null +++ b/mlx/backend/cuda/binary/less.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Less) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/less_equal.cu b/mlx/backend/cuda/binary/less_equal.cu new file mode 100644 index 000000000..7f9bc5161 --- /dev/null +++ b/mlx/backend/cuda/binary/less_equal.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LessEqual) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/log_add_exp.cu b/mlx/backend/cuda/binary/log_add_exp.cu new file mode 100644 index 000000000..17614f862 --- /dev/null +++ b/mlx/backend/cuda/binary/log_add_exp.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LogAddExp) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/logical_and.cu b/mlx/backend/cuda/binary/logical_and.cu new file mode 100644 index 000000000..6bbeb1a4c --- /dev/null +++ b/mlx/backend/cuda/binary/logical_and.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LogicalAnd) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/logical_or.cu b/mlx/backend/cuda/binary/logical_or.cu new file mode 100644 index 000000000..63afdb98c --- /dev/null +++ b/mlx/backend/cuda/binary/logical_or.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(LogicalOr) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/maximum.cu b/mlx/backend/cuda/binary/maximum.cu new file mode 100644 index 000000000..4f6cb6e0b --- /dev/null +++ b/mlx/backend/cuda/binary/maximum.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Maximum) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/minimum.cu b/mlx/backend/cuda/binary/minimum.cu new file mode 100644 index 000000000..ec4c1abb0 --- /dev/null +++ b/mlx/backend/cuda/binary/minimum.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Minimum) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/multiply.cu b/mlx/backend/cuda/binary/multiply.cu new file mode 100644 index 000000000..bfc15fcaa --- /dev/null +++ b/mlx/backend/cuda/binary/multiply.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Multiply) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/not_equal.cu b/mlx/backend/cuda/binary/not_equal.cu new file mode 100644 index 000000000..49f05c90a --- /dev/null +++ b/mlx/backend/cuda/binary/not_equal.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(NotEqual) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/power.cu b/mlx/backend/cuda/binary/power.cu new file mode 100644 index 000000000..cacdc75c4 --- /dev/null +++ b/mlx/backend/cuda/binary/power.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Power) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/remainder.cu b/mlx/backend/cuda/binary/remainder.cu new file mode 100644 index 000000000..a55006ba0 --- /dev/null +++ b/mlx/backend/cuda/binary/remainder.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Remainder) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary/subtract.cu b/mlx/backend/cuda/binary/subtract.cu new file mode 100644 index 000000000..37f3874cc --- /dev/null +++ b/mlx/backend/cuda/binary/subtract.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/binary/binary.cuh" + +namespace mlx::core { +BINARY_GPU(Subtract) +} // namespace mlx::core diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 49a747829..cd0fe2c46 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -127,45 +127,99 @@ binary_two_vv(const In* a, const In* b, Out* out_a, Out* out_b, IdxT size) { } } -template +template < + typename Op, + typename In, + typename Out, + typename IdxT, + int NDIM, + int N_READS> __global__ void binary_two_g_nd( const In* a, const In* b, Out* out_a, Out* out_b, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array a_strides, const __grid_constant__ cuda::std::array b_strides) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc_nd( - index, shape.data(), a_strides.data(), b_strides.data()); - auto out = Op{}(a[a_idx], b[b_idx]); - out_a[index] = out[0]; - out_b[index] = out[1]; + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto a_stride_x = a_strides[NDIM - 1]; + auto b_stride_x = b_strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc_nd( + index_rest * shape_x, shape.data(), a_strides.data(), b_strides.data()); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec_a; + AlignedVector out_vec_b; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec[i], b_vec[i]); + out_vec_a[i] = out[0]; + out_vec_b[i] = out[1]; + } + store_vector(out_a + shape_x * index_rest, index_x, out_vec_a, shape_x); + store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x); } -template +template __global__ void binary_two_g( const In* a, const In* b, Out* out_a, Out* out_b, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides a_strides, const __grid_constant__ Strides b_strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx] = elem_to_loc( - index, shape.data(), a_strides.data(), b_strides.data(), ndim); - auto out = Op{}(a[a_idx], b[b_idx]); - out_a[index] = out[0]; - out_b[index] = out[1]; + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto a_stride_x = a_strides[ndim - 1]; + auto b_stride_x = b_strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx] = elem_to_loc( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + ndim); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, In(0)); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, In(0)); + + AlignedVector out_vec_a; + AlignedVector out_vec_b; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + auto out = Op{}(a_vec[i], b_vec[i]); + out_vec_a[i] = out[0]; + out_vec_b[i] = out[1]; + } + store_vector(out_a + shape_x * 
index_rest, index_x, out_vec_a, shape_x); + store_vector(out_b + shape_x * index_rest, index_x, out_vec_b, shape_x); } template @@ -225,42 +279,64 @@ void binary_two_op_gpu_inplace( auto& a_strides = strides[0]; auto& b_strides = strides[1]; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out_a.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = - get_launch_args(out_a, large()); + auto kernel = cu::binary_two_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 1>; + if (work_per_thread == 4) { + kernel = cu::binary_two_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant(), + 4>; + } encoder.add_kernel_node( - cu::binary_two_g_nd< - Op, - InType, - OutType, - IdxT, - dims_constant()>, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out_a.data(), out_b.data(), - out_a.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides)); }); } else { - auto [num_blocks, block_dims] = - get_launch_args(out_a, large()); + auto kernel = cu::binary_two_g; + if (work_per_thread == 4) { + kernel = cu::binary_two_g; + } encoder.add_kernel_node( - cu::binary_two_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), out_a.data(), out_b.data(), - out_a.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 64c67a176..6ac42751a 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -10,37 +10,80 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void copy_gg_nd( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array strides_in, const __grid_constant__ cuda::std::array strides_out) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [idx_in, idx_out] = elem_to_loc_nd( - index, shape.data(), strides_in.data(), strides_out.data()); - out[idx_out] = CastOp{}(in[idx_in]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto in_stride_x = strides_in[NDIM - 1]; + auto out_stride_x = strides_out[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [idx_in, idx_out] = elem_to_loc_nd( + index_rest * shape_x, + shape.data(), + strides_in.data(), + strides_out.data()); + + auto in_vec = + load_vector(in + idx_in, index_x, shape_x, in_stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x); } -template +template __global__ void copy_gg( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides strides_in, const __grid_constant__ Strides 
strides_out, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [idx_in, idx_out] = elem_to_loc( - index, shape.data(), strides_in.data(), strides_out.data(), ndim); - out[idx_out] = CastOp{}(in[idx_in]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto in_stride_x = strides_in[ndim - 1]; + auto out_stride_x = strides_out[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [idx_in, idx_out] = elem_to_loc( + index_rest * shape_x, + shape.data(), + strides_in.data(), + strides_out.data(), + ndim); + + auto in_vec = + load_vector(in + idx_in, index_x, shape_x, in_stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + idx_out, index_x, out_vec, shape_x, out_stride_x); } } // namespace cu @@ -69,33 +112,52 @@ void copy_general( size_t data_size = 1; for (auto& s : shape) data_size *= s; + + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = data_size / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto ndim_constant) { - auto [num_blocks, block_dims] = - get_launch_args(data_size, shape, out.strides(), large()); + auto kernel = + cu::copy_gg_nd; + if (work_per_thread == 4) { + kernel = + cu::copy_gg_nd; + } encoder.add_kernel_node( - cu::copy_gg_nd, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - data_size, + rest, const_param(shape), const_param(strides_in), const_param(strides_out)); }); } else { // ndim >= 4 - auto [num_blocks, block_dims] = - get_launch_args(data_size, shape, out.strides(), large()); + auto kernel = cu::copy_gg; + if (work_per_thread == 4) { + kernel = cu::copy_gg; + } encoder.add_kernel_node( - cu::copy_gg, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - data_size, + rest, const_param(shape), const_param(strides_in), const_param(strides_out), diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index f381f14fa..ce8bb1b78 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -10,33 +10,67 @@ namespace cu { namespace cg = cooperative_groups; -template +template __global__ void copy_g_nd( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, - const __grid_constant__ cuda::std::array strides_in) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - IdxT idx_in = elem_to_loc_nd(index, shape.data(), strides_in.data()); - out[index] = CastOp{}(in[idx_in]); + const __grid_constant__ cuda::std::array strides) { + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto stride_x = strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * 
block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc_nd(index_rest * shape_x, shape.data(), strides.data()); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } -template +template __global__ void copy_g( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, - const __grid_constant__ Strides strides_in, + const __grid_constant__ Strides strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - IdxT idx_in = elem_to_loc(index, shape.data(), strides_in.data(), ndim); - out[index] = CastOp{}(in[idx_in]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto stride_x = strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = CastOp{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } } // namespace cu @@ -61,30 +95,49 @@ void copy_general_input( const InType* in_ptr = in.data() + offset_in; OutType* out_ptr = out.data() + offset_out; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = + cu::copy_g_nd; + if (work_per_thread == 4) { + kernel = + cu::copy_g_nd; + } encoder.add_kernel_node( - cu::copy_g_nd, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - out.size(), + rest, const_param(shape), const_param(strides_in)); }); } else { // ndim >= 4 - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = cu::copy_g; + if (work_per_thread == 4) { + kernel = cu::copy_g; + } encoder.add_kernel_node( - cu::copy_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in_ptr, out_ptr, - out.size(), + rest, const_param(shape), const_param(strides_in), ndim); diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index bc055c9df..7ebc5d654 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -146,6 +146,23 @@ inline __device__ void store_vector( } } +template +inline __device__ void store_vector( + T* ptr, + uint32_t offset, + const AlignedVector& vec, + SizeT size, + int64_t stride) { + if (is_aligned(ptr) && (offset + 1) * N <= size && stride == 1) { + auto* to = reinterpret_cast*>(ptr); + to[offset] = vec; + } else { + for (int i = 0; (offset * N + i) < size && i < N; ++i) { + ptr[stride * (offset * N + i)] = vec[i]; + } + } +} + 
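+// Note on the strided store_vector overload above: it only takes the
+// vectorized path when the destination is aligned, the whole N-element
+// vector fits within |size|, and the stride is 1; otherwise it falls back
+// to element-wise, bounds-checked stores. The general copy kernels
+// (copy_gg / copy_gg_nd) rely on it because their innermost output
+// dimension is not necessarily contiguous.
+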
/////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index cfc0e10b8..67937fc8e 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -39,52 +39,98 @@ ternary_v(const bool* a, const T* b, const T* c, T* out, IdxT size) { } } -template +template __global__ void ternary_g_nd( const bool* a, const T* b, const T* c, T* out, - IdxT size, + IdxT size_rest, const __grid_constant__ cuda::std::array shape, const __grid_constant__ cuda::std::array a_strides, const __grid_constant__ cuda::std::array b_strides, const __grid_constant__ cuda::std::array c_strides) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx, c_idx] = elem_to_loc_nd( - index, - shape.data(), - a_strides.data(), - b_strides.data(), - c_strides.data()); - out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[NDIM - 1]; + auto a_stride_x = a_strides[NDIM - 1]; + auto b_stride_x = b_strides[NDIM - 1]; + auto c_stride_x = c_strides[NDIM - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx, c_idx] = elem_to_loc_nd( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data()); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, false); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, T(0)); + auto c_vec = + load_vector(c + c_idx, index_x, shape_x, c_stride_x, T(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } -template +template __global__ void ternary_g( const bool* a, const T* b, const T* c, T* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides a_strides, const __grid_constant__ Strides b_strides, const __grid_constant__ Strides c_strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto [a_idx, b_idx, c_idx] = elem_to_loc( - index, - shape.data(), - a_strides.data(), - b_strides.data(), - c_strides.data(), - ndim); - out[index] = Op{}(a[a_idx], b[b_idx], c[c_idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto a_stride_x = a_strides[ndim - 1]; + auto b_stride_x = b_strides[ndim - 1]; + auto c_stride_x = c_strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto [a_idx, b_idx, c_idx] = elem_to_loc( + index_rest * shape_x, + shape.data(), + a_strides.data(), + b_strides.data(), + c_strides.data(), + ndim); + auto a_vec = + load_vector(a + a_idx, index_x, shape_x, a_stride_x, false); + auto b_vec = + load_vector(b + b_idx, index_x, shape_x, b_stride_x, T(0)); + auto c_vec = + load_vector(c + c_idx, index_x, shape_x, c_stride_x, T(0)); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < 
N_READS; ++i) { + out_vec[i] = Op{}(a_vec[i], b_vec[i], c_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } } // namespace cu @@ -123,36 +169,55 @@ void ternary_op_gpu_inplace( auto& b_strides = strides[1]; auto& c_strides = strides[2]; int ndim = shape.size(); + int work_per_thread = 1; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + if (ndim <= 3) { dispatch_1_2_3(ndim, [&](auto dims_constant) { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = + cu::ternary_g_nd; + if (work_per_thread == 4) { + kernel = + cu::ternary_g_nd; + } encoder.add_kernel_node( - cu::ternary_g_nd, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), c.data(), out.data(), - out.size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), const_param(c_strides)); }); } else { - auto [num_blocks, block_dims] = get_launch_args(out, large()); + auto kernel = cu::ternary_g; + if (work_per_thread == 4) { + kernel = cu::ternary_g; + } encoder.add_kernel_node( - cu::ternary_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, a.data(), b.data(), c.data(), out.data(), - out.data_size(), + rest, const_param(shape), const_param(a_strides), const_param(b_strides), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 96888da97..4102dbfb3 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -37,19 +37,36 @@ __global__ void unary_v(const In* in, Out* out, IdxT size) { } } -template +template __global__ void unary_g( const In* in, Out* out, - IdxT size, + IdxT size_rest, const __grid_constant__ Shape shape, const __grid_constant__ Strides strides, int ndim) { - IdxT index = cg::this_grid().thread_rank(); - if (index < size) { - auto idx = elem_to_loc(index, shape.data(), strides.data(), ndim); - out[index] = Op{}(in[idx]); + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; } + + auto shape_x = shape[ndim - 1]; + auto stride_x = strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); } template @@ -127,8 +144,7 @@ void unary_op_gpu_inplace( using OutType = cuda_type_t; if (contig) { using IdxT = std::conditional_t; - // TODO: Choose optimized value based on type size. 
- constexpr int N_READS = 4; + constexpr int N_READS = 16 / sizeof(OutType); auto [num_blocks, block_dims] = get_launch_args( out.data_size(), out.shape(), out.strides(), large, N_READS); encoder.add_kernel_node( @@ -142,18 +158,30 @@ void unary_op_gpu_inplace( } else { using IdxT = std::conditional_t; auto [shape, strides] = collapse_contiguous_dims(in); - auto [num_blocks, block_dims] = get_launch_args(out, large); + auto ndim = shape.size(); + int work_per_thread = 1; + auto kernel = cu::unary_g; + auto dim0 = ndim > 0 ? shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + kernel = cu::unary_g; + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); encoder.add_kernel_node( - cu::unary_g, - num_blocks, + kernel, + {num_blocks_x, num_blocks_y}, block_dims, 0, in.data(), out.data(), - out.data_size(), + rest, const_param(shape), const_param(strides), - shape.size()); + ndim); } }); } else { diff --git a/mlx/backend/cuda/unary/CMakeLists.txt b/mlx/backend/cuda/unary/CMakeLists.txt new file mode 100644 index 000000000..532c5645e --- /dev/null +++ b/mlx/backend/cuda/unary/CMakeLists.txt @@ -0,0 +1,34 @@ +target_sources( + mlx + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/abs.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccos.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arccosh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsin.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arcsinh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctanh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_invert.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/ceil.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/conjugate.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cos.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cosh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/erf_inv.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/exp.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/expm1.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/floor.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/imag.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log1p.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_not.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/negative.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/real.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/round.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sign.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sin.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sinh.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/sqrt.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/square.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tan.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/tanh.cu) diff --git a/mlx/backend/cuda/unary/abs.cu b/mlx/backend/cuda/unary/abs.cu new file mode 100644 index 000000000..90b197d21 --- /dev/null +++ b/mlx/backend/cuda/unary/abs.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Abs) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arccos.cu b/mlx/backend/cuda/unary/arccos.cu new file mode 100644 index 000000000..38849970d --- /dev/null +++ b/mlx/backend/cuda/unary/arccos.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcCos) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arccosh.cu b/mlx/backend/cuda/unary/arccosh.cu new file mode 100644 index 000000000..0ef0738a4 --- /dev/null +++ b/mlx/backend/cuda/unary/arccosh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcCosh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arcsin.cu b/mlx/backend/cuda/unary/arcsin.cu new file mode 100644 index 000000000..07956ee9b --- /dev/null +++ b/mlx/backend/cuda/unary/arcsin.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcSin) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arcsinh.cu b/mlx/backend/cuda/unary/arcsinh.cu new file mode 100644 index 000000000..a7ab63e17 --- /dev/null +++ b/mlx/backend/cuda/unary/arcsinh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcSinh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arctan.cu b/mlx/backend/cuda/unary/arctan.cu new file mode 100644 index 000000000..78639afaa --- /dev/null +++ b/mlx/backend/cuda/unary/arctan.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcTan) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/arctanh.cu b/mlx/backend/cuda/unary/arctanh.cu new file mode 100644 index 000000000..488268c9e --- /dev/null +++ b/mlx/backend/cuda/unary/arctanh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ArcTanh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/bitwise_invert.cu b/mlx/backend/cuda/unary/bitwise_invert.cu new file mode 100644 index 000000000..77b88f30f --- /dev/null +++ b/mlx/backend/cuda/unary/bitwise_invert.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(BitwiseInvert) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/ceil.cu b/mlx/backend/cuda/unary/ceil.cu new file mode 100644 index 000000000..5ee300ffe --- /dev/null +++ b/mlx/backend/cuda/unary/ceil.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Ceil) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/conjugate.cu b/mlx/backend/cuda/unary/conjugate.cu new file mode 100644 index 000000000..1d1d60e77 --- /dev/null +++ b/mlx/backend/cuda/unary/conjugate.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Conjugate) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/cos.cu b/mlx/backend/cuda/unary/cos.cu new file mode 100644 index 000000000..cfceb86ab --- /dev/null +++ b/mlx/backend/cuda/unary/cos.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Cos) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/cosh.cu b/mlx/backend/cuda/unary/cosh.cu new file mode 100644 index 000000000..d5fcc7081 --- /dev/null +++ b/mlx/backend/cuda/unary/cosh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Cosh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/erf.cu b/mlx/backend/cuda/unary/erf.cu new file mode 100644 index 000000000..c7859322b --- /dev/null +++ b/mlx/backend/cuda/unary/erf.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Erf) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/erf_inv.cu b/mlx/backend/cuda/unary/erf_inv.cu new file mode 100644 index 000000000..16bbaba19 --- /dev/null +++ b/mlx/backend/cuda/unary/erf_inv.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(ErfInv) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/exp.cu b/mlx/backend/cuda/unary/exp.cu new file mode 100644 index 000000000..5a566691d --- /dev/null +++ b/mlx/backend/cuda/unary/exp.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Exp) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/expm1.cu b/mlx/backend/cuda/unary/expm1.cu new file mode 100644 index 000000000..15e6ce445 --- /dev/null +++ b/mlx/backend/cuda/unary/expm1.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Expm1) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/floor.cu b/mlx/backend/cuda/unary/floor.cu new file mode 100644 index 000000000..a8c7ab0bb --- /dev/null +++ b/mlx/backend/cuda/unary/floor.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Floor) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/imag.cu b/mlx/backend/cuda/unary/imag.cu new file mode 100644 index 000000000..9e3c05c3b --- /dev/null +++ b/mlx/backend/cuda/unary/imag.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Imag) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/log.cu b/mlx/backend/cuda/unary/log.cu new file mode 100644 index 000000000..1fd2aa680 --- /dev/null +++ b/mlx/backend/cuda/unary/log.cu @@ -0,0 +1,21 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +void Log::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Log::eval_gpu"); + auto& s = out.primitive().stream(); + switch (base_) { + case Base::e: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::two: + unary_op_gpu(inputs, out, name(), s); + break; + case Base::ten: + unary_op_gpu(inputs, out, name(), s); + break; + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/log1p.cu b/mlx/backend/cuda/unary/log1p.cu new file mode 100644 index 000000000..5396c3da0 --- /dev/null +++ b/mlx/backend/cuda/unary/log1p.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Log1p) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/logical_not.cu b/mlx/backend/cuda/unary/logical_not.cu new file mode 100644 index 000000000..7f398707f --- /dev/null +++ b/mlx/backend/cuda/unary/logical_not.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(LogicalNot) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/negative.cu b/mlx/backend/cuda/unary/negative.cu new file mode 100644 index 000000000..9c7e576ec --- /dev/null +++ b/mlx/backend/cuda/unary/negative.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Negative) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/real.cu b/mlx/backend/cuda/unary/real.cu new file mode 100644 index 000000000..361ffd3f9 --- /dev/null +++ b/mlx/backend/cuda/unary/real.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Real) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/round.cu b/mlx/backend/cuda/unary/round.cu new file mode 100644 index 000000000..4e80fdb60 --- /dev/null +++ b/mlx/backend/cuda/unary/round.cu @@ -0,0 +1,18 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +void Round::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Round::eval_gpu"); + assert(inputs.size() == 1); + const auto& in = inputs[0]; + auto& s = out.primitive().stream(); + if (issubdtype(in.dtype(), inexact)) { + unary_op_gpu(inputs, out, name(), s); + } else { + // No-op integer types + out.copy_shared_buffer(in); + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sigmoid.cu b/mlx/backend/cuda/unary/sigmoid.cu new file mode 100644 index 000000000..3d943726c --- /dev/null +++ b/mlx/backend/cuda/unary/sigmoid.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sigmoid) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sign.cu b/mlx/backend/cuda/unary/sign.cu new file mode 100644 index 000000000..d586d8275 --- /dev/null +++ b/mlx/backend/cuda/unary/sign.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sign) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sin.cu b/mlx/backend/cuda/unary/sin.cu new file mode 100644 index 000000000..47a5adc84 --- /dev/null +++ b/mlx/backend/cuda/unary/sin.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sin) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sinh.cu b/mlx/backend/cuda/unary/sinh.cu new file mode 100644 index 000000000..7a73b7fd4 --- /dev/null +++ b/mlx/backend/cuda/unary/sinh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Sinh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/sqrt.cu b/mlx/backend/cuda/unary/sqrt.cu new file mode 100644 index 000000000..21f5f08f2 --- /dev/null +++ b/mlx/backend/cuda/unary/sqrt.cu @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Sqrt::eval_gpu"); + auto& s = out.primitive().stream(); + if (recip_) { + unary_op_gpu(inputs, out, "Rsqrt", s); + } else { + unary_op_gpu(inputs, out, "Sqrt", s); + } +} +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/square.cu b/mlx/backend/cuda/unary/square.cu new file mode 100644 index 000000000..bbb5f5130 --- /dev/null +++ b/mlx/backend/cuda/unary/square.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Square) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/tan.cu b/mlx/backend/cuda/unary/tan.cu new file mode 100644 index 000000000..3039dcdc1 --- /dev/null +++ b/mlx/backend/cuda/unary/tan.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Tan) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/tanh.cu b/mlx/backend/cuda/unary/tanh.cu new file mode 100644 index 000000000..ae69a51b5 --- /dev/null +++ b/mlx/backend/cuda/unary/tanh.cu @@ -0,0 +1,7 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/unary/unary.cuh" + +namespace mlx::core { +UNARY_GPU(Tanh) +} // namespace mlx::core diff --git a/mlx/backend/cuda/unary/unary.cuh b/mlx/backend/cuda/unary/unary.cuh new file mode 100644 index 000000000..a20e119ca --- /dev/null +++ b/mlx/backend/cuda/unary/unary.cuh @@ -0,0 +1,215 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/unary.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/unary_ops.cuh" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" +#include "mlx/primitives.h" + +#include +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void unary_v(const In* in, Out* out, IdxT size) { + IdxT index = cg::this_grid().thread_rank(); + + if ((index + 1) * N_READS > size) { + for (IdxT i = index * N_READS; i < size; ++i) { + out[i] = Op{}(in[i]); + } + } else { + auto in_vec = load_vector(in, index); + + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(in_vec[i]); + } + + store_vector(out, index, out_vec); + } +} + +template +__global__ void unary_g( + const In* in, + Out* out, + IdxT size_rest, + const __grid_constant__ Shape shape, + const __grid_constant__ Strides strides, + int ndim) { + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + IdxT index_rest = + grid.block_index().y * block.dim_threads().y + block.thread_index().y; + if (index_rest >= size_rest) { + return; + } + + auto shape_x = shape[ndim - 1]; + auto stride_x = strides[ndim - 1]; + IdxT index_x = + grid.block_index().x * block.dim_threads().x + block.thread_index().x; + auto idx = + elem_to_loc(index_rest * shape_x, shape.data(), strides.data(), ndim); + auto in_vec = + load_vector(in + idx, index_x, shape_x, stride_x, In(0)); + AlignedVector out_vec; +#pragma unroll + for (int i = 0; i < N_READS; ++i) { + out_vec[i] = Op{}(in_vec[i]); + } + store_vector(out + shape_x * index_rest, index_x, out_vec, shape_x); +} + +template +constexpr bool supports_unary_op() { + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return std::is_same_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + 
std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_floating_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_integral_v && + !std::is_same_v; + } + if (std::is_same_v || std::is_same_v) { + return std::is_same_v && !mlx::core::is_complex_v; + } + if (std::is_same_v) { + return std::is_same_v && mlx::core::is_complex_v; + } + if (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return std::is_same_v && is_inexact_v; + } + if (std::is_same_v || std::is_same_v) { + return mlx::core::is_complex_v && std::is_same_v; + } + if (std::is_same_v) { + return std::is_same_v && std::is_same_v; + } + return false; +} + +} // namespace cu + +template +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + auto& in = inputs[0]; + if (in.size() == 0) { + return; + } + bool contig = in.flags().contiguous; + bool large; + if (!contig) { + large = in.data_size() > INT32_MAX || out.size() > INT32_MAX; + } else { + large = in.data_size() > UINT32_MAX; + } + + auto& encoder = cu::get_command_encoder(s); + encoder.set_input_array(in); + encoder.set_output_array(out); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); + if constexpr (cu::supports_unary_op()) { + dispatch_bool(large, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + if (contig) { + using IdxT = std::conditional_t; + constexpr int N_READS = 16 / sizeof(OutType); + auto [num_blocks, block_dims] = get_launch_args( + out.data_size(), out.shape(), out.strides(), large, N_READS); + encoder.add_kernel_node( + cu::unary_v, + num_blocks, + block_dims, + 0, + in.data(), + out.data(), + out.data_size()); + } else { + using IdxT = std::conditional_t; + auto [shape, strides] = collapse_contiguous_dims(in); + auto ndim = shape.size(); + int work_per_thread = 1; + auto kernel = cu::unary_g; + auto dim0 = ndim > 0 ? 
shape.back() : 1; + auto rest = out.size() / dim0; + if (dim0 >= 4) { + kernel = cu::unary_g; + work_per_thread = 4; + } + dim0 = (dim0 + work_per_thread - 1) / work_per_thread; + auto block_dims = get_block_dims(dim0, rest, 1); + uint32_t num_blocks_x = cuda::ceil_div(dim0, block_dims.x); + uint32_t num_blocks_y = cuda::ceil_div(rest, block_dims.y); + encoder.add_kernel_node( + kernel, + {num_blocks_x, num_blocks_y}, + block_dims, + 0, + in.data(), + out.data(), + rest, + const_param(shape), + const_param(strides), + ndim); + } + }); + } else { + throw std::runtime_error(fmt::format( + "Can not do unary op {} on input of {} with output of {}.", + op, + dtype_to_string(in.dtype()), + dtype_to_string(out.dtype()))); + } + }); + }); +} + +template +void unary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s) { + set_unary_output_data(inputs[0], out); + unary_op_gpu_inplace(inputs, out, op, s); +} + +#define UNARY_GPU(func) \ + void func::eval_gpu(const std::vector& inputs, array& out) { \ + nvtx3::scoped_range r(#func "::eval_gpu"); \ + auto& s = out.primitive().stream(); \ + unary_op_gpu(inputs, out, name(), s); \ + } + +} // namespace mlx::core From 4abb218d21e36f0d3dd9fc47e51387637e96c1ee Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 16 Aug 2025 07:57:30 +0900 Subject: [PATCH 08/20] The naive_conv_2d is no longer used (#2496) --- mlx/backend/metal/kernels/conv.metal | 109 --------------------------- 1 file changed, 109 deletions(-) diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal index 620352144..e169ade71 100644 --- a/mlx/backend/metal/kernels/conv.metal +++ b/mlx/backend/metal/kernels/conv.metal @@ -166,115 +166,6 @@ instantiate_naive_unfold_nd_dims(float32, float); instantiate_naive_unfold_nd_dims(float16, half); instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t); -/////////////////////////////////////////////////////////////////////////////// -/// Slow and naive conv2d kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const int BC = 16> -[[kernel]] void naive_conv_2d( - const device T* in [[buffer(0)]], - const device T* wt [[buffer(1)]], - device T* out [[buffer(2)]], - const constant MLXConvParams<2>& params [[buffer(3)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)simd_gid; - (void)simd_lid; - - out += tid.z * params.out_strides[0]; - in += tid.z * params.in_strides[0]; - - int out_o = tid.y * BN * TN + lid.y * TN; - int out_hw = tid.x * BM * TM + lid.x * TM; - - int out_h[TM]; - int out_w[TN]; - - for (int m = 0; m < TM; ++m) { - int mm = (out_hw + m); - out_h[m] = mm / params.oS[1]; - out_w[m] = mm % params.oS[1]; - } - - T in_local[TM]; - T wt_local[TN]; - T out_local[TM * TN] = {T(0)}; - - for (int h = 0; h < params.wS[0]; ++h) { - for (int w = 0; w < params.wS[1]; ++w) { - for (int c = 0; c < params.C; ++c) { - // Local in - for (int m = 0; m < TM; m++) { - int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0]; - int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1]; - - bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; - 
in_local[m] = valid - ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] - : T(0); - } - - // Load weight - for (int n = 0; n < TN; ++n) { - int o = out_o + n; - wt_local[n] = o < params.O - ? wt[o * params.wt_strides[0] + h * params.wt_strides[1] + - w * params.wt_strides[2] + c] - : T(0); - } - - // Accumulate - for (int m = 0; m < TM; ++m) { - for (int n = 0; n < TN; ++n) { - out_local[m * TN + n] += in_local[m] * wt_local[n]; - } - } - } - } - } - - for (int m = 0; m < TM; ++m) { - for (int n = 0; n < TN; ++n) { - if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && - (out_o + n) < params.O) - out[out_h[m] * params.out_strides[1] + - out_w[m] * params.out_strides[2] + out_o + n] = - out_local[m * TN + n]; - } - } -} - -// Instantiations - -#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ - template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \ - "_tn" #tn)]] [[kernel]] void \ - naive_conv_2d( \ - const device itype* in [[buffer(0)]], \ - const device itype* wt [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant MLXConvParams<2>& params [[buffer(3)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -#define instantiate_naive_conv_2d_blocks(name, itype) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) - -instantiate_naive_conv_2d_blocks(float32, float); -instantiate_naive_conv_2d_blocks(float16, half); -instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); - /////////////////////////////////////////////////////////////////////////////// /// Depthwise convolution kernels /////////////////////////////////////////////////////////////////////////////// From 888b13ed6329667a5e07a2e38297d9f2dda33aa9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 16 Aug 2025 08:17:24 +0900 Subject: [PATCH 09/20] Remove the hack around SmallVector in cpu compile (#2494) --- mlx/backend/cpu/compiled.cpp | 40 +++++++++++++++--------------------- mlx/small_vector.h | 2 +- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/mlx/backend/cpu/compiled.cpp b/mlx/backend/cpu/compiled.cpp index 8aa296619..a21819ac0 100644 --- a/mlx/backend/cpu/compiled.cpp +++ b/mlx/backend/cpu/compiled.cpp @@ -157,10 +157,12 @@ inline void build_kernel( #endif // Start the kernel - os << "void " << kernel_name << "(void** args) {" << std::endl; + os << "void " << kernel_name + << "(int* shape, int64_t** strides, void** args) {" << std::endl; // Add the input arguments int cnt = 0; + int strides_index = 1; for (size_t i = 0; i < inputs.size(); ++i) { // Skip constants from the input list if (is_constant(i)) { @@ -175,8 +177,8 @@ inline void build_kernel( << "];" << std::endl; // Scalars and contiguous need no strides if (!is_scalar(x) && !contiguous) { - os << " const size_t* " << xname << "_strides = (size_t*)args[" << cnt++ - << "];" << std::endl; + os << " const int64_t* " << xname << "_strides = strides[" + << strides_index++ << "];" << std::endl; } } @@ -186,10 +188,8 @@ inline void build_kernel( os << " " << tstr << "* " << namer.get_name(x) << " = (" << tstr << "*)args[" << cnt++ << "];" << std::endl; } - // Add output strides and shape to extract the indices. 
- if (!contiguous) { - os << " const int* shape = (int*)args[" << cnt++ << "];" << std::endl; - } else { + // Add output size + if (contiguous) { os << " const size_t size = (size_t)args[" << cnt++ << "];" << std::endl; } @@ -288,17 +288,8 @@ void Compiled::eval_cpu( auto [contiguous, shape, strides] = compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_); - // Force allocating shape/strides on heap so we can take their data() first - // and then std::move them. - // TODO: Refactor code to avoid heap allocation. - shape.grow(); - for (auto& s : strides) { - s.grow(); - } - // Collect function input arguments. std::vector args; - int strides_index = 1; for (size_t i = 0; i < inputs.size(); ++i) { if (is_constant_(i)) { continue; @@ -306,9 +297,6 @@ void Compiled::eval_cpu( const auto& x = inputs[i]; encoder.set_input_array(x); args.push_back((void*)x.data()); - if (!contiguous && !is_scalar(x)) { - args.push_back(strides[strides_index++].data()); - } } // Get the kernel name from the lib @@ -343,16 +331,20 @@ void Compiled::eval_cpu( args.push_back(x.data()); encoder.set_output_array(x); } - if (!contiguous) { - args.push_back((void*)shape.data()); - } else { + if (contiguous) { args.push_back((void*)outputs[0].data_size()); } - auto fun = (void (*)(void**))fn_ptr; + auto fun = reinterpret_cast(fn_ptr); encoder.dispatch([fun, args = std::move(args), strides = std::move(strides), - shape = std::move(shape)]() mutable { fun(args.data()); }); + shape = std::move(shape)]() mutable { + SmallVector strides_ptrs; + for (auto& s : strides) { + strides_ptrs.push_back(s.data()); + } + fun(shape.data(), strides_ptrs.data(), args.data()); + }); } } // namespace mlx::core diff --git a/mlx/small_vector.h b/mlx/small_vector.h index fc4c1f06c..0a3371058 100644 --- a/mlx/small_vector.h +++ b/mlx/small_vector.h @@ -440,6 +440,7 @@ class SmallVector { end_ = begin_; } + private: // Grows the backing store by a factor of two, and at least to {min_capacity}. // TODO: Move to private after removing external code using this method. MLX_NOINLINE void grow(size_t min_capacity = 0) { @@ -469,7 +470,6 @@ class SmallVector { end_of_storage_ = new_storage + new_capacity; } - private: MLX_NOINLINE void free_storage() { std::destroy_n(begin_, end_ - begin_); if (is_big()) { From 37b440faa89fffc0406c08990f256f3054adfd26 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 16 Aug 2025 09:01:10 +0900 Subject: [PATCH 10/20] Clean up code handling both std::vector and SmallVector (#2493) --- mlx/backend/metal/device.h | 20 +++++--------------- mlx/small_vector.h | 12 ++++++++++++ mlx/utils.cpp | 37 ------------------------------------- mlx/utils.h | 17 +++++++++++++---- 4 files changed, 30 insertions(+), 56 deletions(-) diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 00df2ddeb..fefb7cdc0 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -60,22 +60,12 @@ struct CommandEncoder { enc_->updateFence(fence); } - template - void set_vector_bytes(const SmallVector& vec, size_t nelems, int idx) { - enc_->setBytes(vec.data(), nelems * sizeof(T), idx); + template >> + void set_vector_bytes(const Vec& vec, size_t nelems, int idx) { + enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx); } - template - void set_vector_bytes(const SmallVector& vec, int idx) { - return set_vector_bytes(vec, vec.size(), idx); - } - - // TODO: Code is duplicated but they should be deleted soon. 
- template - void set_vector_bytes(const std::vector& vec, size_t nelems, int idx) { - enc_->setBytes(vec.data(), nelems * sizeof(T), idx); - } - template - void set_vector_bytes(const std::vector& vec, int idx) { + template >> + void set_vector_bytes(const Vec& vec, int idx) { return set_vector_bytes(vec, vec.size(), idx); } diff --git a/mlx/small_vector.h b/mlx/small_vector.h index 0a3371058..143101c82 100644 --- a/mlx/small_vector.h +++ b/mlx/small_vector.h @@ -519,6 +519,18 @@ class SmallVector { std::is_trivially_destructible::value; }; +template +struct is_vector : std::false_type {}; + +template +struct is_vector> : std::true_type {}; + +template +struct is_vector> : std::true_type {}; + +template +inline constexpr bool is_vector_v = is_vector::value; + #undef MLX_HAS_BUILTIN #undef MLX_HAS_ATTRIBUTE #undef MLX_LIKELY diff --git a/mlx/utils.cpp b/mlx/utils.cpp index eac18239e..2a850d9f9 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -259,43 +259,6 @@ std::ostream& operator<<(std::ostream& os, array a) { return os; } -std::ostream& operator<<(std::ostream& os, const SmallVector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const SmallVector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -// TODO: Code is duplicated but they should be deleted soon. -std::ostream& operator<<(std::ostream& os, const std::vector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - -std::ostream& operator<<(std::ostream& os, const std::vector& v) { - os << "("; - for (int i = 0; i < v.size(); ++i) { - os << v[i] << ((i == v.size() - 1) ? "" : ","); - } - os << ")"; - return os; -} - namespace env { int get_var(const char* name, int default_value) { diff --git a/mlx/utils.h b/mlx/utils.h index 451393540..076842f78 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -100,10 +100,6 @@ std::ostream& operator<<(std::ostream& os, const Stream& s); std::ostream& operator<<(std::ostream& os, const Dtype& d); std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); std::ostream& operator<<(std::ostream& os, array a); -std::ostream& operator<<(std::ostream& os, const SmallVector& v); -std::ostream& operator<<(std::ostream& os, const SmallVector& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); -std::ostream& operator<<(std::ostream& os, const std::vector& v); inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { return os << v.real() << (v.imag() >= 0 ? 
"+" : "") << v.imag() << "j"; } @@ -114,6 +110,19 @@ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { return os << static_cast(v); } +template >> +inline std::ostream& operator<<(std::ostream& os, const Vec& v) { + os << "("; + for (auto it = v.begin(); it != v.end(); ++it) { + os << *it; + if (it != std::prev(v.end())) { + os << ","; + } + } + os << ")"; + return os; +} + inline bool is_power_of_2(int n) { return ((n & (n - 1)) == 0) && n != 0; } From 1ba18ff7d97f4610d928026cba8814408e3526c2 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 16 Aug 2025 10:09:18 +0900 Subject: [PATCH 11/20] [CUDA] Fix conv grads with groups (#2495) * Put reshape utils in one file * [CUDA] Fix conv grads with groups * Put the reshape utils in gpu/copy.h --- mlx/backend/common/utils.cpp | 27 -------------- mlx/backend/common/utils.h | 3 -- mlx/backend/cuda/conv.cpp | 46 +++++++++++++++++++++--- mlx/backend/cuda/sort.cu | 1 - mlx/backend/gpu/copy.cpp | 66 ++++++++++++++++++++++++++++++++++ mlx/backend/gpu/copy.h | 8 +++++ mlx/backend/gpu/primitives.cpp | 29 ++------------- python/tests/cuda_skip.py | 1 - 8 files changed, 119 insertions(+), 62 deletions(-) diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 4c9e39dc6..ae169e35e 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -228,31 +228,4 @@ std::pair get_grid_and_block_common(int dim0, int dim1, int dim2) { std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); } -array swapaxes_in_eval(const array& x, int axis1, int axis2) { - int ndim = x.ndim(); - if (axis1 < 0) { - axis1 += ndim; - } - if (axis2 < 0) { - axis2 += ndim; - } - - auto shape = x.shape(); - std::swap(shape[axis1], shape[axis2]); - auto strides = x.strides(); - std::swap(strides[axis1], strides[axis2]); - - auto [data_size, row_contiguous, col_contiguous] = - check_contiguity(shape, strides); - bool contiguous = data_size == x.data_size(); - - array out(std::move(shape), x.dtype(), nullptr, {}); - out.copy_shared_buffer( - x, - std::move(strides), - {contiguous, row_contiguous, col_contiguous}, - x.data_size()); - return out; -} - } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index db0da5e10..1b6902ff3 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -196,9 +196,6 @@ void shared_buffer_reshape( const Strides& out_strides, array& out); -// Like the swapaxes op but safe to call in eval_gpu. -array swapaxes_in_eval(const array& x, int axis1, int axis2); - template inline SmallVector remove_index(SmallVector vec, size_t index) { vec.erase(std::next(vec.begin(), index)); diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 1484e8c46..3d7ef60bc 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -336,6 +336,42 @@ std::optional build_op_graph( } } +// Transpose from (C_out, H, W, C_in / groups) to (C_in, H, W, C_out / groups). 
+array group_transpose( + const array& x, + int groups, + int group_dim, + int axis1, + int axis2, + Stream s) { + if (groups == 1) { + return swapaxes_in_eval(x, axis1, axis2); + } + int ndim = x.ndim(); + if (group_dim < 0) { + group_dim += ndim; + } + if (axis1 < 0) { + axis1 += ndim; + } + if (axis2 < 0) { + axis2 += ndim; + } + if (group_dim <= axis1) { + axis1 += 1; + } + if (group_dim <= axis2) { + axis2 += 1; + } + auto shape = x.shape(); + shape.insert(shape.begin() + group_dim, groups); + shape[group_dim + 1] = shape[group_dim + 1] / groups; + array x_trans = reshape_in_eval(x, std::move(shape), s); + x_trans = swapaxes_in_eval(x_trans, axis1, axis2); + x_trans = flatten_in_eval(x_trans, group_dim, group_dim + 1, s); + return x_trans; +} + // Do necessary transposes and copies to prepare the inputs and outputs for // building the cuDNN conv op. It is safe to be called multiple times in one // eval_gpu, with cost of possible redundant copies. @@ -345,13 +381,14 @@ std::tuple prepare_args( array in, array wt, array out, + int groups, Stream s) { // Transpose the args depending on the backend type. // TODO: Handle groups. if (backend_type == CONV_BACKWARD_INPUT) { - wt = swapaxes_in_eval(wt, 0, -1); + wt = group_transpose(wt, groups, 0, 0, -1, s); } else if (backend_type == CONV_BACKWARD_WEIGHT) { - in = swapaxes_in_eval(in, 0, -1); + in = group_transpose(in, groups, -1, 0, -1, s); wt = swapaxes_in_eval(wt, 0, -1); // Create a contiguous array that shares the data with |out|, but with dim // C_in and C_out swapped. @@ -457,7 +494,8 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { get_alignment(out)}; if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) { auto& [backend_type, plan] = it->second; - std::tie(in, wt, out) = prepare_args(encoder, backend_type, in, wt, out, s); + std::tie(in, wt, out) = + prepare_args(encoder, backend_type, in, wt, out, groups_, s); register_args(encoder, backend_type, in, wt, out, out_); auto [x, w, y] = dispatch_args(backend_type, in, wt, out); if (!execute_plan(encoder, plan, x, w, y)) { @@ -490,7 +528,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { std::optional op_graph; for (auto try_backend : try_backends) { auto [in_copy, wt_copy, out_copy] = - prepare_args(encoder, try_backend, in, wt, out, s); + prepare_args(encoder, try_backend, in, wt, out, groups_, s); auto [x, w, y] = dispatch_args(try_backend, in_copy, wt_copy, out_copy); auto [stride, padding_lo, padding_hi, dilation] = get_conv_op_settings( try_backend, diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 5bbd72fd5..e81ae12eb 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -1,6 +1,5 @@ // Copyright © 2025 Apple Inc. 
-#include "mlx/backend/common/utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp index 4556f7d98..472ee486b 100644 --- a/mlx/backend/gpu/copy.cpp +++ b/mlx/backend/gpu/copy.cpp @@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) { return arr_copy; } +void reshape_gpu(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) { + int ndim = x.ndim(); + if (start_axis < 0) { + start_axis += ndim; + } + if (end_axis < 0) { + end_axis += ndim; + } + start_axis = std::max(0, start_axis); + end_axis = std::min(ndim - 1, end_axis); + + return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s); +} + +array reshape_in_eval(const array& x, Shape shape, Stream s) { + array out(std::move(shape), x.dtype(), nullptr, {}); + reshape_gpu(x, out, s); + return out; +} + +array swapaxes_in_eval(const array& x, int axis1, int axis2) { + int ndim = x.ndim(); + if (axis1 < 0) { + axis1 += ndim; + } + if (axis2 < 0) { + axis2 += ndim; + } + + auto shape = x.shape(); + std::swap(shape[axis1], shape[axis2]); + auto strides = x.strides(); + std::swap(strides[axis1], strides[axis2]); + + auto [data_size, row_contiguous, col_contiguous] = + check_contiguity(shape, strides); + bool contiguous = data_size == x.data_size(); + + array out(std::move(shape), x.dtype(), nullptr, {}); + out.copy_shared_buffer( + x, + std::move(strides), + {contiguous, row_contiguous, col_contiguous}, + x.data_size()); + return out; +} + } // namespace mlx::core diff --git a/mlx/backend/gpu/copy.h b/mlx/backend/gpu/copy.h index f01fe9fda..274250202 100644 --- a/mlx/backend/gpu/copy.h +++ b/mlx/backend/gpu/copy.h @@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s); // Return a contiguous array with same shape that copies the data of |arr|. array contiguous_copy_gpu(const array& arr, const Stream& s); +// Copy data from |in| and transpose to |out|'s shape. +void reshape_gpu(const array& in, array& out, Stream s); + +// Like the normal ops but safe to call in eval_gpu. 
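+// reshape/flatten may allocate and launch a copy on the stream when the input
+// cannot be reshaped as a view; swapaxes only permutes the shape and strides
+// of a shared buffer and never copies.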
+array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s); +array reshape_in_eval(const array& x, Shape shape, Stream s); +array swapaxes_in_eval(const array& x, int axis1, int axis2); + } // namespace mlx::core diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 56d389b4f..6017879a5 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -20,29 +20,6 @@ namespace mlx::core { -namespace { - -void reshape(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - -} // namespace - void AsStrided::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("AsStrided::eval_gpu"); eval(inputs, out); @@ -124,7 +101,7 @@ void Full::eval_gpu(const std::vector& inputs, array& out) { void Flatten::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Flatten::eval_gpu"); - reshape(inputs[0], out, stream()); + reshape_gpu(inputs[0], out, stream()); } void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { @@ -150,7 +127,7 @@ void Pad::eval_gpu(const std::vector& inputs, array& out) { void Reshape::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Reshape::eval_gpu"); - reshape(inputs[0], out, stream()); + reshape_gpu(inputs[0], out, stream()); } void Split::eval_gpu( @@ -224,7 +201,7 @@ void Transpose::eval_gpu(const std::vector& inputs, array& out) { void Unflatten::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Unflatten::eval_gpu"); - reshape(inputs[0], out, stream()); + reshape_gpu(inputs[0], out, stream()); } void View::eval_gpu(const std::vector& inputs, array& out) { diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 0f57100e8..c635de9ad 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -17,7 +17,6 @@ cuda_skip = { "TestConv.test_1d_conv_with_2d", "TestConv.test_conv_1d_groups_flipped", "TestConv.test_conv_general_flip_grad", - "TestConv.test_conv_groups_grad", "TestConv.test_torch_conv_2D", "TestConv.test_torch_conv_depthwise", "TestConv.test_torch_conv_general", From c422050ca734d3ba1b13547c02b01b9a68ae7663 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sun, 17 Aug 2025 19:13:01 +0900 Subject: [PATCH 12/20] Update cuDNN Frontend to v1.14 (#2505) --- mlx/backend/cuda/CMakeLists.txt | 2 +- mlx/backend/cuda/conv.cpp | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 0d526400d..38708be5e 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -149,7 +149,7 @@ target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver) FetchContent_Declare( cudnn GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git - GIT_TAG v1.12.1 + GIT_TAG v1.14.0 GIT_SHALLOW TRUE EXCLUDE_FROM_ALL) set(CUDNN_FRONTEND_SKIP_JSON_LIB ON) diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 3d7ef60bc..29d4803a8 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -7,9 +7,6 @@ #include "mlx/dtype_utils.h" #include "mlx/primitives.h" -// cudnn_frontend.h redefines this macro. 
-#undef CHECK_CUDA_ERROR - #include #include #include From 73f22d622633acf3dcd2a2e043fff1947916db13 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 17 Aug 2025 08:42:20 -0700 Subject: [PATCH 13/20] Ensure small sort doesn't use indices if not argsort (#2506) --- mlx/backend/metal/kernels/sort.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/kernels/sort.h b/mlx/backend/metal/kernels/sort.h index b067150d8..5823e4300 100644 --- a/mlx/backend/metal/kernels/sort.h +++ b/mlx/backend/metal/kernels/sort.h @@ -45,7 +45,9 @@ struct ThreadSort { for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { if (op(vals[j + 1], vals[j])) { thread_swap(vals[j + 1], vals[j]); - thread_swap(idxs[j + 1], idxs[j]); + if (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } } } } @@ -111,7 +113,9 @@ struct BlockMergeSort { bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); vals[i] = pred ? b : a; - idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + if (ARG_SORT) { + idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + } b_idx += short(pred); a_idx += short(!pred); From 1df988799834e1ce0ad96dd5bb71c3302aa212fc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 17 Aug 2025 08:42:33 -0700 Subject: [PATCH 14/20] Ensure no oob read in gemv_masked (#2508) --- mlx/backend/metal/kernels/gemv_masked.h | 96 +++++++++++++------------ 1 file changed, 49 insertions(+), 47 deletions(-) diff --git a/mlx/backend/metal/kernels/gemv_masked.h b/mlx/backend/metal/kernels/gemv_masked.h index 75bc7354c..96b0c2821 100644 --- a/mlx/backend/metal/kernels/gemv_masked.h +++ b/mlx/backend/metal/kernels/gemv_masked.h @@ -262,36 +262,37 @@ struct GEMVKernel { vec_mask_offset += vec_mask_step; } - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - load_safe(in_vec, v_coeff, bn, in_size); - - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] *= block_scale; + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); } - } - // Per thread work loop - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + load_safe(in_vec, v_coeff, bn, in_size); - // Accumulate results + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } } } } @@ -544,31 +545,32 @@ struct GEMVTKernel { vec_mask_offset += vec_mask_step; } - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - for (int tm = 0; 
tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; if (has_mul_operand_mask) { - v_coeff[tm] *= block_scale; + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); } - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } } } } From c5fcd5b61b2483a157388dae12099e6c71c33713 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Aug 2025 06:45:59 -0700 Subject: [PATCH 15/20] fix custom kernel test (#2510) --- docs/src/dev/custom_metal_kernels.rst | 2 +- mlx/ops.cpp | 2 +- python/tests/test_fast.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/dev/custom_metal_kernels.rst b/docs/src/dev/custom_metal_kernels.rst index 873b1e544..1febe960a 100644 --- a/docs/src/dev/custom_metal_kernels.rst +++ b/docs/src/dev/custom_metal_kernels.rst @@ -128,6 +128,7 @@ relying on a copy from ``ensure_row_contiguous``: input_names=["inp"], output_names=["out"], source=source + ensure_row_contiguous=False, ) def exp_elementwise(a: mx.array): @@ -138,7 +139,6 @@ relying on a copy from ``ensure_row_contiguous``: threadgroup=(256, 1, 1), output_shapes=[a.shape], output_dtypes=[a.dtype], - ensure_row_contiguous=False, ) return outputs[0] diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 14e135deb..c8583c72f 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -2971,7 +2971,7 @@ array gather( } for (auto& x : indices) { if (x.dtype() == bool_) { - throw("[Gather] Boolean indices not supported."); + throw std::invalid_argument("[Gather] Boolean indices not supported."); } } diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 13c65de99..f79a62a15 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -639,12 +639,12 @@ class TestFast(mlx_tests.MLXTestCase): ], grid=(6, 1, 1), threadgroup=(2, 1, 1), - output_shapes=[(2, 2), (3, 2)], + output_shapes=[(3, 2), (3, 2)], output_dtypes=[mx.float32, mx.int32], stream=mx.gpu, ) - self.assertTrue(mx.allclose(out[0], mx.full((2, 2), 14.0484))) + self.assertTrue(mx.allclose(out[0], mx.full((3, 2), 14.0484))) self.assertTrue(mx.allclose(out[1], mx.full((3, 2), -2, dtype=mx.int32))) @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") From e7c6e1db82ff40b569b5b20d6d720538229d766a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Aug 2025 08:33:38 -0700 Subject: [PATCH 16/20] no segfault with uninitialized array.at (#2514) --- python/src/array.cpp | 15 +++++++++++++++ python/tests/test_array.py | 3 +++ 2 files changed, 18 insertions(+) diff --git a/python/src/array.cpp b/python/src/array.cpp index 22ef8e273..143d2e6f5 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -28,30 +28,45 @@ class ArrayAt { public: ArrayAt(mx::array x) : x_(std::move(x)) {} ArrayAt& set_indices(nb::object indices) { + initialized_ = true; indices_ = 
indices; return *this; } + void check_initialized() { + if (!initialized_) { + throw std::invalid_argument( + "Must give indices to array.at (e.g. `x.at[0].add(4)`)."); + } + } + mx::array add(const ScalarOrArray& v) { + check_initialized(); return mlx_add_item(x_, indices_, v); } mx::array subtract(const ScalarOrArray& v) { + check_initialized(); return mlx_subtract_item(x_, indices_, v); } mx::array multiply(const ScalarOrArray& v) { + check_initialized(); return mlx_multiply_item(x_, indices_, v); } mx::array divide(const ScalarOrArray& v) { + check_initialized(); return mlx_divide_item(x_, indices_, v); } mx::array maximum(const ScalarOrArray& v) { + check_initialized(); return mlx_maximum_item(x_, indices_, v); } mx::array minimum(const ScalarOrArray& v) { + check_initialized(); return mlx_minimum_item(x_, indices_, v); } private: mx::array x_; + bool initialized_{false}; nb::object indices_; }; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 3ab41bef7..ae1cb784f 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1365,6 +1365,9 @@ class TestArray(mlx_tests.MLXTestCase): def test_array_at(self): a = mx.array(1) + with self.assertRaises(ValueError): + a.at.add(1) + a = a.at[None].add(1) self.assertEqual(a.item(), 2) From cea93696109e24a5605516c5a4b48c030c15003a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 18 Aug 2025 15:07:59 -0700 Subject: [PATCH 17/20] fix lapack svd (#2515) --- mlx/backend/cpu/lapack.h | 2 +- mlx/backend/cpu/svd.cpp | 38 ++++++-------------------------------- 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/mlx/backend/cpu/lapack.h b/mlx/backend/cpu/lapack.h index b242093ff..ce735f26c 100644 --- a/mlx/backend/cpu/lapack.h +++ b/mlx/backend/cpu/lapack.h @@ -47,7 +47,7 @@ INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(geev) INSTANTIATE_LAPACK_REAL(potrf) -INSTANTIATE_LAPACK_REAL(gesvdx) +INSTANTIATE_LAPACK_REAL(gesdd) INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(trtri) diff --git a/mlx/backend/cpu/svd.cpp b/mlx/backend/cpu/svd.cpp index 08ad444e1..6e57eb401 100644 --- a/mlx/backend/cpu/svd.cpp +++ b/mlx/backend/cpu/svd.cpp @@ -81,9 +81,7 @@ void svd_impl( // Vᵀ of shape N x N. (M x M in lapack). const int ldvt = M; - auto job_u = (u_ptr) ? "V" : "N"; - auto job_vt = (u_ptr) ? "V" : "N"; - static constexpr auto range = "A"; + auto jobz = (u_ptr) ? "A" : "N"; // Will contain the number of singular values after the call has returned. int ns = 0; @@ -91,30 +89,20 @@ void svd_impl( // Will contain the indices of eigenvectors that failed to converge (not // used here but required by lapack). - auto iwork = array::Data{allocator::malloc(sizeof(int) * 12 * K)}; + auto iwork = array::Data{allocator::malloc(sizeof(int) * 8 * K)}; static const int lwork_query = -1; - static const int ignored_int = 0; - static const T ignored_float = 0; - int info; // Compute workspace size. - gesvdx( - /* jobu = */ job_u, - /* jobvt = */ job_vt, - /* range = */ range, + gesdd( + /* jobz = */ jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ nullptr, /* lda = */ &lda, - /* vl = */ &ignored_float, - /* vu = */ &ignored_float, - /* il = */ &ignored_int, - /* iu = */ &ignored_int, - /* ns = */ &ns, /* s = */ nullptr, /* u = */ nullptr, /* ldu = */ &ldu, @@ -136,20 +124,13 @@ void svd_impl( // Loop over matrices. 
for (int i = 0; i < num_matrices; i++) { - gesvdx( - /* jobu = */ job_u, - /* jobvt = */ job_vt, - /* range = */ range, + gesdd( + /* jobz = */ jobz, // M and N are swapped since lapack expects column-major. /* m = */ &N, /* n = */ &M, /* a = */ in_ptr + M * N * i, /* lda = */ &lda, - /* vl = */ &ignored_float, - /* vu = */ &ignored_float, - /* il = */ &ignored_int, - /* iu = */ &ignored_int, - /* ns = */ &ns, /* s = */ s_ptr + K * i, // According to the identity above, lapack will write Vᵀᵀ as U. /* u = */ vt_ptr ? vt_ptr + N * N * i : nullptr, @@ -167,13 +148,6 @@ void svd_impl( ss << "svd_impl: sgesvdx_ failed with code " << info; throw std::runtime_error(ss.str()); } - - if (ns != K) { - std::stringstream ss; - ss << "svd_impl: expected " << K << " singular values, but " << ns - << " were computed."; - throw std::runtime_error(ss.str()); - } } }); encoder.add_temporary(in); From 65d0d402321f8cacaf0143591d77f4025b44b6b9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 20 Aug 2025 09:29:28 +0900 Subject: [PATCH 18/20] Split cuDNN helpers into a separate header (#2491) * Add RAII managed CudaGraph class * Implement forward rms_norm with cuDNN * Revert back to old rms norm kernel --- mlx/backend/cuda/CMakeLists.txt | 1 + mlx/backend/cuda/conv.cpp | 248 +++--------------------------- mlx/backend/cuda/cudnn_utils.cpp | 252 +++++++++++++++++++++++++++++++ mlx/backend/cuda/cudnn_utils.h | 164 ++++++++++++++++++++ mlx/backend/cuda/device.cpp | 14 +- mlx/backend/cuda/device.h | 4 +- mlx/backend/cuda/utils.cpp | 50 +++--- mlx/backend/cuda/utils.h | 96 +++++++----- 8 files changed, 527 insertions(+), 302 deletions(-) create mode 100644 mlx/backend/cuda/cudnn_utils.cpp create mode 100644 mlx/backend/cuda/cudnn_utils.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 38708be5e..c529af1d2 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -17,6 +17,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cu diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 29d4803a8..4c8016e9d 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -1,15 +1,11 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/cudnn_utils.h" #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/gpu/copy.h" -#include "mlx/dtype_utils.h" #include "mlx/primitives.h" -#include -#include -#include #include #include @@ -18,9 +14,6 @@ namespace mlx::core { namespace { -// Not all engines support it so can not use this API now. -#define MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API 0 - // Alias for better readability. 
#define CONV_FORWARD CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR #define CONV_BACKWARD_INPUT \ @@ -52,198 +45,6 @@ auto& conv_cache() { return cache; } -template -inline SmallVector convert_vector(const Vec& vec) { - return SmallVector(vec.begin(), vec.end()); -} - -template class Vec> -inline std::array fixed_vector(const Vec& vec) { - if (vec.size() > MAX_NDIM) { - throw std::runtime_error( - fmt::format("ndim can not be larger than {}.", MAX_NDIM)); - } - std::array result = {}; - std::copy_n(vec.begin(), vec.size(), result.begin()); - return result; -} - -auto nhwc_to_nchw(const array& x) { - auto shape = convert_vector(x.shape()); - shape.insert(shape.begin() + 1, shape.back()); - shape.erase(shape.end() - 1); - auto strides = convert_vector(x.strides()); - strides.insert(strides.begin() + 1, strides.back()); - strides.erase(strides.end() - 1); - return std::make_tuple(std::move(shape), std::move(strides)); -} - -inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) { - switch (dtype) { - case int8: - return CUDNN_DATA_INT8; - case int32: - return CUDNN_DATA_INT32; - case uint8: - return CUDNN_DATA_UINT8; - case float16: - return CUDNN_DATA_HALF; - case bfloat16: - return CUDNN_DATA_BFLOAT16; - case float32: - return CUDNN_DATA_FLOAT; - case float64: - return CUDNN_DATA_DOUBLE; - default: - throw std::runtime_error(fmt::format( - "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype))); - } -} - -inline uint8_t get_alignment(const array& x) { - uint8_t alignment = 1; - uintptr_t address = reinterpret_cast(x.data()); - for (; alignment < 32; alignment *= 2) { - if (address % (alignment * 2)) { - return alignment; - } - } - return alignment; -} - -inline cudnn_frontend::Tensor build_tensor(int64_t id, const array& x) { - auto [shape, strides] = nhwc_to_nchw(x); - return cudnn_frontend::TensorBuilder() - .setDim(shape.size(), shape.data()) - .setStrides(strides.size(), strides.data()) - .setId(id) - .setAlignment(get_alignment(x)) - .setDataType(dtype_to_cudnn_type(x.dtype())) - .build(); -} - -cudnn_frontend::EngineConfigList get_engine_configs( - cudnnBackendDescriptorType_t backend_type, - Dtype dtype, - cudnn_frontend::OperationGraph& op_graph, - bool use_fallback = false) { - cudnn_frontend::GeneratorSource source; - if (use_fallback) { - source = [&backend_type](cudnn_frontend::OperationGraph& op_graph) { - auto fallback = cudnn_frontend::EngineFallbackListBuilder() - .setOperationGraph(op_graph) - .setOperation(backend_type) - .build(); - return fallback.getFallbackList(); - }; - } else { - source = [](cudnn_frontend::OperationGraph& op_graph) { - auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() - .setOperationGraph(op_graph) - .setHeurMode(CUDNN_HEUR_MODE_A) - .build(); - return heuristics.getEngineConfig(heuristics.getEngineConfigCount()); - }; - } - - cudnn_frontend::EngineConfigGenerator generator(1, &source); - auto configs = generator.generate_engine_config(op_graph); - - cudnn_frontend::EngineConfigList filtered_configs; - cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) { - if (cudnn_frontend::hasNumericalNote< - CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) { - return true; - } - if (cudnn_frontend::hasNumericalNote(c) && - dtype == float32 && !env::enable_tf32()) { - return true; - } - return false; - }); - return filtered_configs; -} - -bool execute_plan( - cu::CommandEncoder& encoder, - cudnn_frontend::ExecutionPlan& plan, - array& x, - array& w, - array& y) { - int workspace_size = plan.getWorkspaceSize(); - array 
workspace(allocator::malloc(workspace_size), {workspace_size}, uint8); - - int64_t uids[3] = {'x', 'w', 'y'}; - void* data_ptrs[3] = { - x.data(), - w.data(), - y.data(), - }; - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace.data()) - .setDataPointers(3, data_ptrs) - .setUids(3, uids) - .build(); - - auto handle = encoder.device().cudnn_handle(); - cudnnSetStream(handle, encoder.stream()); - -#if CUDNN_VERSION >= 90500 && MLX_USE_CUDNN_NATIVE_CUDA_GRAPH_API - cudaGraph_t graph; - cudaGraphCreate(&graph, 0); - std::unique_ptr graph_freer( - &graph, [](cudaGraph_t* p) { cudaGraphDestroy(*p); }); - if (cudnnBackendPopulateCudaGraph( - handle, plan.get_raw_desc(), variantPack.get_raw_desc(), graph) != - CUDNN_STATUS_SUCCESS) { - return false; - } - encoder.add_graph_node(graph); -#else - auto capture = encoder.capture_context(); - if (cudnnBackendExecute( - handle, plan.get_raw_desc(), variantPack.get_raw_desc()) != - CUDNN_STATUS_SUCCESS) { - // Discard the captured graph when failed. - capture.discard = true; - return false; - } -#endif - - encoder.add_temporary(workspace); - return true; -} - -bool try_engines( - cu::CommandEncoder& encoder, - const ConvCacheKey& cache_key, - cudnnBackendDescriptorType_t backend_type, - cudnn_frontend::EngineConfigList& configs, - const std::string& op_graph_tag, - array& x, - array& w, - array& y) { - for (auto& config : configs) { - try { - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(encoder.device().cudnn_handle()) - .setEngineConfig(config, op_graph_tag) - .build(); - if (execute_plan(encoder, plan, x, w, y)) { - conv_cache().emplace( - cache_key, std::make_pair(backend_type, std::move(plan))); - return true; - } - } catch (cudnn_frontend::cudnnException& error) { - if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) { - throw; - } - } - } - return false; -} - auto get_conv_op_settings( cudnnBackendDescriptorType_t backend_type, array& x, @@ -288,7 +89,7 @@ auto get_conv_op_settings( } } -std::optional build_op_graph( +std::optional build_conv_op_graph( cu::CommandEncoder& encoder, cudnnBackendDescriptorType_t backend_type, Dtype dtype, @@ -314,9 +115,9 @@ std::optional build_op_graph( .build(); auto op = cudnn_frontend::OperationBuilder(backend_type) - .setxDesc(build_tensor('x', x)) - .setwDesc(build_tensor('w', w)) - .setyDesc(build_tensor('y', y)) + .setxDesc(build_cudnn_tensor_nchw('x', x)) + .setwDesc(build_cudnn_tensor_nchw('w', w)) + .setyDesc(build_cudnn_tensor_nchw('y', y)) .setcDesc(conv_desc) .build(); @@ -478,12 +279,12 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { ConvCacheKey cache_key{ encoder.device().cuda_device(), dtype_to_cudnn_type(dtype), - fixed_vector(in.shape()), - fixed_vector(wt.shape()), - fixed_vector(kernel_strides_), - fixed_vector(padding_lo_), - fixed_vector(padding_hi_), - fixed_vector(kernel_dilation_), + vector_key(in.shape()), + vector_key(wt.shape()), + vector_key(kernel_strides_), + vector_key(padding_lo_), + vector_key(padding_hi_), + vector_key(kernel_dilation_), groups_, flip_, get_alignment(in), @@ -495,7 +296,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { prepare_args(encoder, backend_type, in, wt, out, groups_, s); register_args(encoder, backend_type, in, wt, out, out_); auto [x, w, y] = dispatch_args(backend_type, in, wt, out); - if (!execute_plan(encoder, plan, x, w, y)) { + if (!encode_cudnn_plan(encoder, plan, {'x', 'w', 'y'}, x, w, y)) { throw std::runtime_error("[conv] Cached plan 
failed to execute."); } return; @@ -537,7 +338,7 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { padding_hi_, kernel_dilation_, input_dilation_); - op_graph = build_op_graph( + op_graph = build_conv_op_graph( encoder, try_backend, dtype, @@ -560,22 +361,21 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { throw std::runtime_error("[conv] Can not build op graph."); } - // Get ready to execute the graph. + // Setup inputs and outputs. register_args(encoder, backend_type, in, wt, out, out_); - // Try to run plans based on heuristics. - auto configs = get_engine_configs(backend_type, dtype, *op_graph); - auto tag = op_graph->getTag(); + // Find a plan for the graph and execute it. + auto plan = find_cudnn_plan_from_op_graph( + encoder.device().cudnn_handle(), backend_type, dtype, *op_graph); + if (!plan) { + throw std::runtime_error("[conv] Unable to find an execution plan."); + } auto [x, w, y] = dispatch_args(backend_type, in, wt, out); - if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) { - return; + if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { + throw std::runtime_error("[conv] Failed to run execution plan."); } - // Then try fallback plans. - configs = get_engine_configs(backend_type, dtype, *op_graph); - if (try_engines(encoder, cache_key, backend_type, configs, tag, x, w, y)) { - return; - } - throw std::runtime_error("[conv] Unable to find a working engine."); + conv_cache().emplace( + cache_key, std::make_pair(backend_type, std::move(*plan))); } } // namespace mlx::core diff --git a/mlx/backend/cuda/cudnn_utils.cpp b/mlx/backend/cuda/cudnn_utils.cpp new file mode 100644 index 000000000..76bcc5b0b --- /dev/null +++ b/mlx/backend/cuda/cudnn_utils.cpp @@ -0,0 +1,252 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/cudnn_utils.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core { + +namespace { + +// Create a cudnn tensor descriptor. +template +inline cudnn_frontend::Tensor build_cudnn_tensor( + int64_t id, + const array& x, + const Vec& shape, + const Vec& strides) { + return cudnn_frontend::TensorBuilder() + .setDim(shape.size(), shape.data()) + .setStrides(strides.size(), strides.data()) + .setId(id) + .setAlignment(get_alignment(x)) + .setDataType(dtype_to_cudnn_type(x.dtype())) + .build(); +} + +// Return the shape and strides after transposing from NHWC to NCHW. +auto nhwc_to_nchw(SmallVector shape, SmallVector strides) { + assert(shape.size() >= 3); + shape.insert(shape.begin() + 1, shape.back()); + shape.erase(shape.end() - 1); + strides.insert(strides.begin() + 1, strides.back()); + strides.erase(strides.end() - 1); + return std::make_tuple(std::move(shape), std::move(strides)); +} + +auto nhwc_to_nchw(const array& x) { + return nhwc_to_nchw(convert_vector(x.shape()), x.strides()); +} + +// Return available engines for a |op_graph|. 
+cudnn_frontend::EngineConfigList get_cudnn_engine_configs( + cudnnBackendDescriptorType_t backend_type, + Dtype dtype, + cudnn_frontend::OperationGraph& op_graph, + bool use_fallback = true) { + SmallVector sources; + sources.push_back([](auto& op_graph) { + auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() + .setOperationGraph(op_graph) + .setHeurMode(CUDNN_HEUR_MODE_A) + .build(); + return heuristics.getEngineConfig(heuristics.getEngineConfigCount()); + }); + if (use_fallback) { + sources.push_back([&backend_type](auto& op_graph) { + auto fallback = cudnn_frontend::EngineFallbackListBuilder() + .setOperationGraph(op_graph) + .setOperation(backend_type) + .build(); + return fallback.getFallbackList(); + }); + } + + auto configs = + cudnn_frontend::EngineConfigGenerator(sources.size(), sources.data()) + .generate_engine_config(op_graph); + + cudnn_frontend::EngineConfigList filtered_configs; + cudnn_frontend::filter(configs, filtered_configs, [dtype](auto c) { + if (cudnn_frontend::hasNumericalNote< + CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) { + return true; + } + if (cudnn_frontend::hasNumericalNote(c) && + dtype == float32 && !env::enable_tf32()) { + return true; + } + return false; + }); + return filtered_configs; +} + +// Take |engine_configs| and |op_graph| and find a working execution plans +// from them. +std::optional +find_cudnn_plan_from_engine_configs( + cudnnHandle_t handle, + const cudnn_frontend::EngineConfigList& engine_configs, + const cudnn_frontend::OperationGraph& op_graph) { + auto op_graph_tag = op_graph.getTag(); + for (const auto& config : engine_configs) { + try { + return cudnn_frontend::ExecutionPlanBuilder() + .setHandle(handle) + .setEngineConfig(config, op_graph_tag) + .build(); + } catch (cudnn_frontend::cudnnException& error) { + if (error.getCudnnStatus() != CUDNN_STATUS_NOT_SUPPORTED) { + throw; + } + } + } + return std::nullopt; +} + +// Prepare workspace and args to execute plan. +template +bool prepare_cudnn_plan( + cu::CommandEncoder& encoder, + cudnn_frontend::ExecutionPlan& plan, + int num_args, + const int64_t* uids, + void** data_ptrs, + F&& execute) { + int workspace_size = plan.getWorkspaceSize(); + array workspace( + workspace_size > 0 ? 
allocator::malloc(workspace_size) + : allocator::Buffer(nullptr), + {workspace_size}, + uint8); + + auto args = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(workspace.data()) + .setDataPointers(num_args, data_ptrs) + .setUids(num_args, uids) + .build(); + + auto handle = encoder.device().cudnn_handle(); + cudnnSetStream(handle, encoder.stream()); + + if (!execute(handle, plan.get_raw_desc(), args.get_raw_desc())) { + return false; + } + + encoder.add_temporary(workspace); + return true; +} + +} // namespace + +cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x) { + auto shape = convert_vector(x.shape()); + return build_cudnn_tensor(id, x, shape, x.strides()); +} + +cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x) { + auto [shape, strides] = nhwc_to_nchw(x); + return build_cudnn_tensor(id, x, shape, strides); +} + +cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x) { + if (x.ndim() == 0) { + SmallVector scalar_dims = {1, 1, 1, 1}; + return build_cudnn_tensor(id, x, scalar_dims, scalar_dims); + } + if (x.ndim() == 1) { + int64_t s = x.shape(0); + SmallVector shape = {1, x.shape(0), 1, 1}; + SmallVector strides = {s, 1, s, s}; + return build_cudnn_tensor(id, x, shape, strides); + } + if (x.ndim() == 2) { + int64_t s = x.strides(0); + SmallVector shape = {x.shape(0), x.shape(1), 1, 1}; + SmallVector strides = {s, x.strides(1), s, s}; + return build_cudnn_tensor(id, x, shape, strides); + } + if (x.ndim() == 3 || x.ndim() == 4) { + return build_cudnn_tensor_nchw(id, x); + } + throw std::runtime_error( + fmt::format("Unsupported array with {} dims.", x.ndim())); +} + +cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype) { + SmallVector scalar_dims = {1, 1, 1, 1}; + return cudnn_frontend::TensorBuilder() + .setDim(scalar_dims.size(), scalar_dims.data()) + .setStrides(scalar_dims.size(), scalar_dims.data()) + .setId(id) + .setAlignment(16) + .setDataType(dtype_to_cudnn_type(dtype)) + .setByValue(true) + .build(); +} + +std::optional find_cudnn_plan_from_op_graph( + cudnnHandle_t handle, + cudnnBackendDescriptorType_t backend_type, + Dtype dtype, + cudnn_frontend::OperationGraph& op_graph) { + auto engine_configs = get_cudnn_engine_configs(backend_type, dtype, op_graph); + return find_cudnn_plan_from_engine_configs(handle, engine_configs, op_graph); +} + +bool encode_cudnn_plan_with_capturing( + cu::CommandEncoder& encoder, + cudnn_frontend::ExecutionPlan& plan, + int num_args, + const int64_t* uids, + void** data_ptrs) { + return prepare_cudnn_plan( + encoder, + plan, + num_args, + uids, + data_ptrs, + [&](auto handle, auto plan, auto args) { + auto capture = encoder.capture_context(); + if (cudnnBackendExecute(handle, plan, args) != CUDNN_STATUS_SUCCESS) { + // Discard the captured graph when failed. 
+ capture.discard = true; + return false; + } + return true; + }); +} + +#if CUDNN_VERSION >= 90500 +bool encode_cudnn_plan_with_graph_api( + cu::CommandEncoder& encoder, + cudnn_frontend::ExecutionPlan& plan, + CudaGraph& graph, + int num_args, + const int64_t* uids, + void** data_ptrs) { + return prepare_cudnn_plan( + encoder, + plan, + num_args, + uids, + data_ptrs, + [&](auto handle, auto plan, auto args) { + if (!graph) { + graph = CudaGraph(encoder.device()); + if (cudnnBackendPopulateCudaGraph(handle, plan, args, graph) != + CUDNN_STATUS_SUCCESS) { + return false; + } + } else { + if (cudnnBackendUpdateCudaGraph(handle, plan, args, graph) != + CUDNN_STATUS_SUCCESS) { + return false; + } + } + encoder.add_graph_node(graph); + return true; + }); +} +#endif + +} // namespace mlx::core diff --git a/mlx/backend/cuda/cudnn_utils.h b/mlx/backend/cuda/cudnn_utils.h new file mode 100644 index 000000000..c35c5cac9 --- /dev/null +++ b/mlx/backend/cuda/cudnn_utils.h @@ -0,0 +1,164 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device/config.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/dtype_utils.h" + +#include +#include +#include + +#include +#include + +namespace mlx::core { + +namespace cu { +class CommandEncoder; +} + +// Return pointer alignment of |x|'s data. +inline uint8_t get_alignment(const array& x) { + uint8_t alignment = 1; + uintptr_t address = reinterpret_cast(x.data()); + for (; alignment < 32; alignment *= 2) { + if (address % (alignment * 2)) { + return alignment; + } + } + return alignment; +} + +// Convert the type of elements in |vec| to |T|. +template +inline SmallVector convert_vector(const Vec& vec) { + return SmallVector(vec.begin(), vec.end()); +} + +// Return an array that can be used as map key for |vec| with size <= MAX_NDIM. +// +// There are 2 differences from the const_param util from kernel_utils.cuh: +// 1. The rest of array is filled with 0. +// 2. This util can be used in .cpp files. +template class Vec> +inline std::array vector_key(const Vec& vec) { + if (vec.size() > MAX_NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", MAX_NDIM)); + } + std::array result = {}; + std::copy_n(vec.begin(), vec.size(), result.begin()); + return result; +} + +// Helpers used by get_data_ptrs to get pointers. +inline void* get_data_ptr(const array& arr) { + return const_cast(arr.data()); +} + +template >> +inline void* get_data_ptr(T& scalar) { + return &scalar; +} + +// Return an array filled with data pointers of args. +template +inline std::array get_data_ptrs(Args&... args) { + return {get_data_ptr(args)...}; +} + +// Map dtype to cudnn data type. +inline cudnnDataType_t dtype_to_cudnn_type(Dtype dtype) { + switch (dtype) { + case int8: + return CUDNN_DATA_INT8; + case int32: + return CUDNN_DATA_INT32; + case uint8: + return CUDNN_DATA_UINT8; + case float16: + return CUDNN_DATA_HALF; + case bfloat16: + return CUDNN_DATA_BFLOAT16; + case float32: + return CUDNN_DATA_FLOAT; + case float64: + return CUDNN_DATA_DOUBLE; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in Convolution: {}.", dtype_to_string(dtype))); + } +} + +// Create a tensor descriptor from |x|. +cudnn_frontend::Tensor build_cudnn_tensor(int64_t id, const array& x); + +// Create a tensor descriptor from |x|, and transpose from NHWC to NCHW. 
+cudnn_frontend::Tensor build_cudnn_tensor_nchw(int64_t id, const array& x); + +// Create a tensor descriptor from |x|, make sure it is 4D, and transpose it +// from NHWC to NCHW. +cudnn_frontend::Tensor build_cudnn_tensor_4d_nchw(int64_t id, const array& x); + +// Create a 4D scalar tensor descriptor, which is passed by value. +cudnn_frontend::Tensor build_cudnn_scalar_4d(int64_t id, Dtype dtype); + +// Find a working plan for |op_graph|. +std::optional find_cudnn_plan_from_op_graph( + cudnnHandle_t handle, + cudnnBackendDescriptorType_t backend_type, + Dtype dtype, + cudnn_frontend::OperationGraph& op_graph); + +// Encode the plan to command buffer by capturing. +bool encode_cudnn_plan_with_capturing( + cu::CommandEncoder& encoder, + cudnn_frontend::ExecutionPlan& plan, + int num_args, + const int64_t* uids, + void** data_ptrs); + +#if CUDNN_VERSION >= 90500 +// Encode the plan to command buffer by using native graph api of cudnn. If the +// |graph| is empty it will be populated, otherwise it will be updated. +bool encode_cudnn_plan_with_graph_api( + cu::CommandEncoder& encoder, + cudnn_frontend::ExecutionPlan& plan, + CudaGraph& graph, + int num_args, + const int64_t* uids, + void** data_ptrs); +#endif + +// Helpers to make calls like encode_cudnn_plan(..., {'x', 'y', 'z'}, x, y, z). +template +bool encode_cudnn_plan( + cu::CommandEncoder& encoder, + cudnn_frontend::ExecutionPlan& plan, + std::initializer_list uids, + Args&... args) { + assert(uids.size() == sizeof...(args)); + auto data_ptrs = get_data_ptrs(args...); + return encode_cudnn_plan_with_capturing( + encoder, plan, uids.size(), uids.begin(), data_ptrs.data()); +} + +#if CUDNN_VERSION >= 90500 +template +bool encode_cudnn_plan( + cu::CommandEncoder& encoder, + cudnn_frontend::ExecutionPlan& plan, + CudaGraph& graph, + std::initializer_list uids, + Args&... args) { + assert(uids.size() == sizeof...(args)); + auto data_ptrs = get_data_ptrs(args...); + return encode_cudnn_plan_with_graph_api( + encoder, plan, graph, uids.size(), uids.begin(), data_ptrs.data()); +} +#endif + +} // namespace mlx::core diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 96b07502f..371ae020c 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -91,9 +91,7 @@ CommandEncoder::CaptureContext::CaptureContext(CommandEncoder& enc) : enc(enc) { } CommandEncoder::CaptureContext::~CaptureContext() { - CHECK_CUDA_ERROR(cudaStreamEndCapture(enc.stream(), &graph)); - std::unique_ptr graph_freer( - &graph, [](cudaGraph_t* p) { CHECK_CUDA_ERROR(cudaGraphDestroy(*p)); }); + graph.end_capture(enc.stream()); if (discard) { return; } @@ -185,9 +183,10 @@ void CommandEncoder::insert_graph_dependencies(std::vector nodes) { } CommandEncoder::CommandEncoder(Device& d) - : device_(d), stream_(d), graph_cache_(cuda_graph_cache_size()) { - CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); -} + : device_(d), + stream_(d), + graph_(d), + graph_cache_(cuda_graph_cache_size()) {} void CommandEncoder::add_completed_handler(std::function task) { worker_.add_task(std::move(task)); @@ -311,8 +310,7 @@ void CommandEncoder::commit() { to_nodes_.clear(); graph_key_.clear(); node_map_.clear(); - CHECK_CUDA_ERROR(cudaGraphDestroy(graph_)); - CHECK_CUDA_ERROR(cudaGraphCreate(&graph_, 0)); + graph_ = CudaGraph(device_); } // Put completion handlers in a batch. 
diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 5eb7fd4c1..7b0ff5629 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -21,7 +21,7 @@ class CommandEncoder { struct CaptureContext { CaptureContext(CommandEncoder& enc); ~CaptureContext(); - cudaGraph_t graph; + CudaGraph graph; CommandEncoder& enc; bool discard{false}; }; @@ -115,7 +115,7 @@ class CommandEncoder { Device& device_; CudaStream stream_; - cudaGraph_t graph_; + CudaGraph graph_; Worker worker_; char node_count_{0}; char graph_node_count_{0}; diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 88940a234..09894d4ca 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -8,36 +8,6 @@ namespace mlx::core { -CudaStream::CudaStream(cu::Device& device) { - device.make_current(); - CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); -} - -CudaStream::~CudaStream() { - CHECK_CUDA_ERROR(cudaStreamDestroy(stream_)); -} - -CudaGraphExec::CudaGraphExec(cudaGraphExec_t handle) : handle_(handle) {} - -CudaGraphExec::CudaGraphExec(CudaGraphExec&& other) : handle_(other.handle_) { - other.handle_ = nullptr; -}; - -CudaGraphExec::~CudaGraphExec() { - reset(); -} - -void CudaGraphExec::instantiate(cudaGraph_t graph) { - CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0)); -} - -void CudaGraphExec::reset() { - if (handle_ != nullptr) { - CHECK_CUDA_ERROR(cudaGraphExecDestroy(handle_)); - handle_ = nullptr; - } -} - void check_cublas_error(const char* name, cublasStatus_t err) { if (err != CUBLAS_STATUS_SUCCESS) { // TODO: Use cublasGetStatusString when it is widely available. @@ -96,4 +66,24 @@ const char* dtype_to_cuda_type(const Dtype& dtype) { } } +CudaGraph::CudaGraph(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaGraphCreate(&handle_, 0)); +} + +void CudaGraph::end_capture(cudaStream_t stream) { + assert(handle_ == nullptr); + CHECK_CUDA_ERROR(cudaStreamEndCapture(stream, &handle_)); +} + +void CudaGraphExec::instantiate(cudaGraph_t graph) { + assert(handle_ == nullptr); + CHECK_CUDA_ERROR(cudaGraphInstantiate(&handle_, graph, nullptr, nullptr, 0)); +} + +CudaStream::CudaStream(cu::Device& device) { + device.make_current(); + CHECK_CUDA_ERROR(cudaStreamCreateWithFlags(&handle_, cudaStreamNonBlocking)); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/utils.h b/mlx/backend/cuda/utils.h index 555e15065..e811d5e6c 100644 --- a/mlx/backend/cuda/utils.h +++ b/mlx/backend/cuda/utils.h @@ -16,44 +16,6 @@ class Device; struct Dtype; -// Cuda stream managed with RAII. -class CudaStream { - public: - explicit CudaStream(cu::Device& device); - ~CudaStream(); - - CudaStream(const CudaStream&) = delete; - CudaStream& operator=(const CudaStream&) = delete; - - operator cudaStream_t() const { - return stream_; - } - - private: - cudaStream_t stream_; -}; - -// Move-able RAII handle of cudaGraphExec_t. -class CudaGraphExec { - public: - CudaGraphExec(cudaGraphExec_t handle = nullptr); - CudaGraphExec(CudaGraphExec&& other); - ~CudaGraphExec(); - - CudaGraphExec(const CudaGraphExec&) = delete; - CudaGraphExec& operator=(const CudaGraphExec&) = delete; - - void instantiate(cudaGraph_t graph); - void reset(); - - operator cudaGraphExec_t() const { - return handle_; - } - - private: - cudaGraphExec_t handle_; -}; - // Throw exception if the cuda API does not succeed. 
void check_cublas_error(const char* name, cublasStatus_t err); void check_cuda_error(const char* name, cudaError_t err); @@ -66,4 +28,62 @@ void check_cuda_error(const char* name, CUresult err); // Convert Dtype to CUDA C++ types. const char* dtype_to_cuda_type(const Dtype& dtype); +// Base class for RAII managed CUDA resources. +template +class CudaHandle { + public: + CudaHandle(Handle handle = nullptr) : handle_(handle) {} + + CudaHandle(CudaHandle&& other) : handle_(other.handle_) { + assert(this != &other); + other.handle_ = nullptr; + } + + ~CudaHandle() { + reset(); + } + + CudaHandle(const CudaHandle&) = delete; + CudaHandle& operator=(const CudaHandle&) = delete; + + CudaHandle& operator=(CudaHandle&& other) { + assert(this != &other); + reset(); + std::swap(handle_, other.handle_); + return *this; + } + + void reset() { + if (handle_ != nullptr) { + CHECK_CUDA_ERROR(Destroy(handle_)); + handle_ = nullptr; + } + } + + operator Handle() const { + return handle_; + } + + protected: + Handle handle_; +}; + +// Wrappers of CUDA resources. +class CudaGraph : public CudaHandle { + public: + using CudaHandle::CudaHandle; + explicit CudaGraph(cu::Device& device); + void end_capture(cudaStream_t stream); +}; + +class CudaGraphExec : public CudaHandle { + public: + void instantiate(cudaGraph_t graph); +}; + +class CudaStream : public CudaHandle { + public: + explicit CudaStream(cu::Device& device); +}; + } // namespace mlx::core From ac85ddfdb70a0e80f2281f42146d3a89999b1609 Mon Sep 17 00:00:00 2001 From: Cheng Date: Wed, 20 Aug 2025 10:06:22 +0900 Subject: [PATCH 19/20] [CUDA] Add GEMM-based fallback convolution kernels (#2511) * Add gemm_conv * Add gemm_grouped_conv --- mlx/backend/cuda/CMakeLists.txt | 2 + mlx/backend/cuda/conv.cpp | 85 +++++--- mlx/backend/cuda/conv/conv.h | 126 +++++++++++ mlx/backend/cuda/conv/gemm_conv.cu | 217 +++++++++++++++++++ mlx/backend/cuda/conv/gemm_grouped_conv.cu | 231 +++++++++++++++++++++ mlx/backend/cuda/gemms/cublas_gemm.cpp | 19 ++ mlx/backend/cuda/gemms/cublas_gemm.h | 11 + python/tests/cuda_skip.py | 8 - 8 files changed, 667 insertions(+), 32 deletions(-) create mode 100644 mlx/backend/cuda/conv/conv.h create mode 100644 mlx/backend/cuda/conv/gemm_conv.cu create mode 100644 mlx/backend/cuda/conv/gemm_grouped_conv.cu diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index c529af1d2..994307284 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -16,6 +16,8 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu + ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 4c8016e9d..63188fbc8 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -1,5 +1,6 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/conv/conv.h" #include "mlx/backend/cuda/cudnn_utils.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/lru_cache.h" @@ -21,6 +22,9 @@ namespace { #define CONV_BACKWARD_WEIGHT \ CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR +// Custom placeholder representing fallback kernel. 
+#define CONV_FALLBACK static_cast(-1) + struct ConvCacheKey { int device_id; cudnnDataType_t cudnn_dtype; @@ -40,7 +44,9 @@ struct ConvCacheKey { auto& conv_cache() { static LRUBytesKeyCache< ConvCacheKey, - std::pair> + std::pair< + cudnnBackendDescriptorType_t, + std::optional>> cache(/* capacity */ 128); return cache; } @@ -292,12 +298,29 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { get_alignment(out)}; if (auto it = conv_cache().find(cache_key); it != conv_cache().end()) { auto& [backend_type, plan] = it->second; - std::tie(in, wt, out) = - prepare_args(encoder, backend_type, in, wt, out, groups_, s); - register_args(encoder, backend_type, in, wt, out, out_); - auto [x, w, y] = dispatch_args(backend_type, in, wt, out); - if (!encode_cudnn_plan(encoder, plan, {'x', 'w', 'y'}, x, w, y)) { - throw std::runtime_error("[conv] Cached plan failed to execute."); + if (plan) { + // Run cached plan. + std::tie(in, wt, out) = + prepare_args(encoder, backend_type, in, wt, out, groups_, s); + register_args(encoder, backend_type, in, wt, out, out_); + auto [x, w, y] = dispatch_args(backend_type, in, wt, out); + if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { + throw std::runtime_error("[conv] Cached plan failed to execute."); + } + } else { + // Run fallback kernel. + gemm_conv( + encoder, + in, + wt, + out, + kernel_strides_, + padding_lo_, + kernel_dilation_, + input_dilation_, + groups_, + flip_, + s); } return; } @@ -357,25 +380,39 @@ void Convolution::eval_gpu(const std::vector& inputs, array& out_) { break; } } - if (!op_graph) { - throw std::runtime_error("[conv] Can not build op graph."); + + if (op_graph) { + // Setup inputs and outputs. + register_args(encoder, backend_type, in, wt, out, out_); + + // Find a plan for the graph and execute it. + auto plan = find_cudnn_plan_from_op_graph( + encoder.device().cudnn_handle(), backend_type, dtype, *op_graph); + if (!plan) { + throw std::runtime_error("[conv] Unable to find an execution plan."); + } + auto [x, w, y] = dispatch_args(backend_type, in, wt, out); + if (encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { + conv_cache().emplace( + cache_key, std::make_pair(backend_type, std::move(*plan))); + return; + } } - // Setup inputs and outputs. - register_args(encoder, backend_type, in, wt, out, out_); - - // Find a plan for the graph and execute it. - auto plan = find_cudnn_plan_from_op_graph( - encoder.device().cudnn_handle(), backend_type, dtype, *op_graph); - if (!plan) { - throw std::runtime_error("[conv] Unable to find an execution plan."); - } - auto [x, w, y] = dispatch_args(backend_type, in, wt, out); - if (!encode_cudnn_plan(encoder, *plan, {'x', 'w', 'y'}, x, w, y)) { - throw std::runtime_error("[conv] Failed to run execution plan."); - } - conv_cache().emplace( - cache_key, std::make_pair(backend_type, std::move(*plan))); + // Use fallback kernel for settings not supported by cuDNN. + gemm_conv( + encoder, + in, + wt, + out, + kernel_strides_, + padding_lo_, + kernel_dilation_, + input_dilation_, + groups_, + flip_, + s); + conv_cache().emplace(cache_key, std::make_pair(CONV_FALLBACK, std::nullopt)); } } // namespace mlx::core diff --git a/mlx/backend/cuda/conv/conv.h b/mlx/backend/cuda/conv/conv.h new file mode 100644 index 000000000..62dc9343e --- /dev/null +++ b/mlx/backend/cuda/conv/conv.h @@ -0,0 +1,126 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { + +template +struct ConvParams { + int N; // Batch size + int C; // In channels + int O; // Out channels + int strides[NDIM]; + int padding[NDIM]; + int kernel_dilation[NDIM]; + int input_dilation[NDIM]; + int groups; + bool flip; + int in_spatial_dims[NDIM]; + int wt_spatial_dims[NDIM]; + int out_spatial_dims[NDIM]; + int64_t in_strides[NDIM + 2]; + + ConvParams( + const array& in, + const array& wt, + const array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip) + : N(in.shape(0)), + C(in.shape(-1)), + O(wt.shape(0)), + groups(groups), + flip(flip) { + std::copy_n(strides.begin(), NDIM, this->strides); + std::copy_n(padding.begin(), NDIM, this->padding); + std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation); + std::copy_n(input_dilation.begin(), NDIM, this->input_dilation); + std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims); + std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims); + std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims); + std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides); + } +}; + +void gemm_grouped_conv( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s); + +void gemm_conv( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s); + +inline void gemm_conv( + cu::CommandEncoder& encoder, + array in, + array wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s) { + if (!in.flags().row_contiguous) { + in = contiguous_copy_gpu(in, s); + encoder.add_temporary(in); + } + if (!wt.flags().row_contiguous) { + wt = contiguous_copy_gpu(wt, s); + encoder.add_temporary(wt); + } + + if (groups == 1) { + gemm_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + flip, + s); + } else { + gemm_grouped_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu new file mode 100644 index 000000000..11a78a7ab --- /dev/null +++ b/mlx/backend/cuda/conv/gemm_conv.cu @@ -0,0 +1,217 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/conv/conv.h" +#include "mlx/backend/cuda/gemms/cublas_gemm.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void naive_unfold_nd( + const T* in, + T* out, + int filter_size, + int out_pixels, + const __grid_constant__ ConvParams params) { + auto block = cg::this_thread_block(); + auto tid = block.group_index(); + auto lid = block.thread_index(); + + int index_batch = tid.z / out_pixels; // [0, N) + int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out) + int index_wt_spatial = + tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt) + + if (index_wt_spatial >= filter_size / params.C) { + return; + } + + in += tid.y; // [0, C) + out += tid.z * filter_size + index_wt_spatial * params.C + tid.y; + + bool valid = index_batch < params.N; + + // Get the coordinates in input. + int index_in[NDIM] = {}; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int index_out = index_out_spatial % params.out_spatial_dims[i]; + int index_wt = index_wt_spatial % params.wt_spatial_dims[i]; + + if (params.flip) { + index_wt = params.wt_spatial_dims[i] - index_wt - 1; + } + + int index = index_out * params.strides[i] - params.padding[i] + + index_wt * params.kernel_dilation[i]; + int index_max = + 1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1); + + valid &= (index >= 0) && (index < index_max) && + (index % params.input_dilation[i] == 0); + + index_in[i] = index / params.input_dilation[i]; + + index_out_spatial /= params.out_spatial_dims[i]; + index_wt_spatial /= params.wt_spatial_dims[i]; + } + + if (valid) { + int in_offset = index_batch * params.in_strides[0]; +#pragma unroll + for (int i = 0; i < NDIM; ++i) { + in_offset += index_in[i] * params.in_strides[i + 1]; + } + *out = in[in_offset]; + } else { + *out = T{0}; + } +} + +} // namespace cu + +template +array unfold_inputs_nd( + cu::CommandEncoder& encoder, + const array& in, + int mat_M, + int mat_K, + int mat_N, + ConvParams& params) { + array unfolded({mat_M, mat_K}, in.dtype(), nullptr, {}); + unfolded.set_data(allocator::malloc(unfolded.nbytes())); + encoder.add_temporary(unfolded); + + int filter_size = params.C; +#pragma unroll + for (int i = 0; i < NDIM; ++i) { + filter_size *= params.wt_spatial_dims[i]; + } + + int out_pixels = 1; +#pragma unroll + for (int i = 0; i < NDIM; ++i) { + out_pixels *= params.out_spatial_dims[i]; + } + + int wt_spatial_size = mat_K / params.C; + dim3 block_dims; + block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024); + dim3 num_blocks; + num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x); + num_blocks.y = params.C; + num_blocks.z = mat_M; + + encoder.set_input_array(in); + encoder.set_output_array(unfolded); + dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) { + using DataType = cuda_type_t; + encoder.add_kernel_node( + cu::naive_unfold_nd, + num_blocks, + block_dims, + 0, + in.data(), + unfolded.data(), + filter_size, + out_pixels, + params); + }); + + return unfolded; +} + +template +void gemm_conv_nd( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + ConvParams& params, + Stream s) { + // Get gemm shapes. + int mat_M = out.size() / params.O; // N * H_out * W_out + int mat_K = wt.size() / params.O; // C * H_wt * W_wt + int mat_N = params.O; // O + + // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm. 
+ array in_unfolded = + unfold_inputs_nd(encoder, in, mat_M, mat_K, mat_N, params); + + // Reshape weight to (C * H_wt * W_wt, O) for gemm. + array wt_reshaped({mat_K, mat_N}, wt.dtype(), nullptr, {}); + wt_reshaped.copy_shared_buffer( + wt, + {1, mat_K}, + {false, false, /* col_contiguous */ true}, + wt.data_size()); + + // Single batch. + Shape batch_shape{1}; + Strides a_batch_strides{0}; + Strides b_batch_strides{0}; + + // Run matmul. + CublasGemm gemm( + encoder.device(), + in.dtype(), + false, // a_transposed + mat_M, // a_rows + mat_K, // a_cols + mat_K, // lda + true, // b_transposed + mat_K, // b_rows + mat_N, // b_cols + mat_K, // ldb + batch_shape.back(), + a_batch_strides.back(), + b_batch_strides.back()); + gemm.run( + encoder, + out, + in_unfolded, + wt_reshaped, + batch_shape, + a_batch_strides, + b_batch_strides); +} + +void gemm_conv( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s) { + int conv_ndim = in.ndim() - 2; + if (conv_ndim < 1 || conv_ndim > 3) { + throw std::runtime_error( + fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim)); + } + dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) { + ConvParams params( + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + 1, // groups + flip); + gemm_conv_nd(encoder, in, wt, out, params, s); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu new file mode 100644 index 000000000..7ceb58166 --- /dev/null +++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu @@ -0,0 +1,231 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/cuda/conv/conv.h" +#include "mlx/backend/cuda/gemms/cublas_gemm.h" +#include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/dtype_utils.h" + +#include + +namespace mlx::core { + +namespace cu { + +namespace cg = cooperative_groups; + +template +__global__ void naive_grouped_unfold_transpose_nd( + const T* in, + T* out, + int filter_size, + int out_pixels, + const __grid_constant__ ConvParams params) { + auto block = cg::this_thread_block(); + auto tid = block.group_index(); + auto lid = block.thread_index(); + + int index_batch = tid.z / out_pixels; // [0, N) + int index_out_spatial = tid.z % out_pixels; // [0, H_out * W_out) + int index_wt_spatial = + tid.x * block.dim_threads().x + lid.x; // [0, H_wt * W_wt) + + if (index_wt_spatial >= filter_size / params.C) { + return; + } + + in += tid.y; // [0, C) + out += tid.z * filter_size + tid.y * (filter_size / params.C); + + bool valid = index_batch < params.N; + + // Get the coordinates in input. 
+ int index_in[NDIM] = {}; + int wt_stride = 1; +#pragma unroll + for (int i = NDIM - 1; i >= 0; --i) { + int index_out = index_out_spatial % params.out_spatial_dims[i]; + int index_wt = index_wt_spatial % params.wt_spatial_dims[i]; + out += index_wt * wt_stride; + + if (params.flip) { + index_wt = params.wt_spatial_dims[i] - index_wt - 1; + } + + int index = index_out * params.strides[i] - params.padding[i] + + index_wt * params.kernel_dilation[i]; + int index_max = + 1 + params.input_dilation[i] * (params.in_spatial_dims[i] - 1); + + valid &= (index >= 0) && (index < index_max) && + (index % params.input_dilation[i] == 0); + + index_in[i] = index / params.input_dilation[i]; + + index_out_spatial /= params.out_spatial_dims[i]; + index_wt_spatial /= params.wt_spatial_dims[i]; + wt_stride *= params.wt_spatial_dims[i]; + } + + if (valid) { + int in_offset = index_batch * params.in_strides[0]; +#pragma unroll + for (int i = 0; i < NDIM; ++i) { + in_offset += index_in[i] * params.in_strides[i + 1]; + } + *out = in[in_offset]; + } else { + *out = T{0}; + } +} + +} // namespace cu + +template +array grouped_unfold_transpose_inputs_nd( + cu::CommandEncoder& encoder, + const array& in, + int mat_M, + int mat_K, + int mat_N, + ConvParams& params) { + array unfolded({mat_M, mat_K * params.groups}, in.dtype(), nullptr, {}); + unfolded.set_data(allocator::malloc(unfolded.nbytes())); + encoder.add_temporary(unfolded); + + int filter_size = params.C; +#pragma unroll + for (int i = 0; i < NDIM; ++i) { + filter_size *= params.wt_spatial_dims[i]; + } + + int out_pixels = 1; +#pragma unroll + for (int i = 0; i < NDIM; ++i) { + out_pixels *= params.out_spatial_dims[i]; + } + + int wt_spatial_size = (mat_K * params.groups) / params.C; + dim3 block_dims; + block_dims.x = std::min(std::max(wt_spatial_size, 32), 1024); + dim3 num_blocks; + num_blocks.x = cuda::ceil_div(wt_spatial_size, block_dims.x); + num_blocks.y = params.C; + num_blocks.z = mat_M; + + encoder.set_input_array(in); + encoder.set_output_array(unfolded); + dispatch_float_types(in.dtype(), "unfold", [&](auto type_tag) { + using DataType = cuda_type_t; + encoder.add_kernel_node( + cu::naive_grouped_unfold_transpose_nd, + num_blocks, + block_dims, + 0, + in.data(), + unfolded.data(), + filter_size, + out_pixels, + params); + }); + + return unfolded; +} + +template +void gemm_grouped_conv_nd( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + ConvParams& params, + Stream s) { + // Get gemm shapes. + int C_per_group = params.C / params.groups; + int O_per_group = params.O / params.groups; + int mat_M = out.size() / params.O; // N * H_out * W_out + int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt + int mat_N = O_per_group; // O_per_group + + // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm. + array in_unfolded = grouped_unfold_transpose_inputs_nd( + encoder, in, mat_M, mat_K, mat_N, params); + + // Reshape weight to (O, C_per_group, H_wt * W_wt) for gemm. + int wt_spatial_size = (wt.size() / wt.shape(0)) / wt.shape(-1); + array wt_view( + {params.O, C_per_group, wt_spatial_size}, wt.dtype(), nullptr, {}); + wt_view.copy_shared_buffer( + wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size()); + array wt_reshaped = contiguous_copy_gpu(wt_view, s); + + // Batch with size of groups. + Shape batch_shape{params.groups}; + Strides a_batch_strides{mat_K}; + Strides b_batch_strides{mat_N * mat_K}; + + // Run matmul. 
+ CublasGemm gemm(
+ encoder.device(),
+ in.dtype(),
+ false, // a_transposed
+ mat_M, // a_rows
+ mat_K, // a_cols
+ mat_K * params.groups, // lda
+ true, // b_transposed
+ mat_K, // b_rows
+ mat_N, // b_cols
+ mat_K, // ldb
+ batch_shape.back(),
+ a_batch_strides.back(),
+ b_batch_strides.back());
+ gemm.set_out(
+ out.dtype(),
+ false, // out_transposed
+ mat_M, // out_rows
+ mat_N, // out_cols
+ mat_N * params.groups, // out_ld
+ params.groups, // batch_count
+ mat_N); // batch_stride
+ gemm.run(
+ encoder,
+ out,
+ in_unfolded,
+ wt_reshaped,
+ batch_shape,
+ a_batch_strides,
+ b_batch_strides);
+}
+
+void gemm_grouped_conv(
+ cu::CommandEncoder& encoder,
+ const array& in,
+ const array& wt,
+ array& out,
+ const std::vector<int>& strides,
+ const std::vector<int>& padding,
+ const std::vector<int>& kernel_dilation,
+ const std::vector<int>& input_dilation,
+ int groups,
+ bool flip,
+ Stream s) {
+ int conv_ndim = in.ndim() - 2;
+ if (conv_ndim < 1 || conv_ndim > 3) {
+ throw std::runtime_error(
+ fmt::format("[conv] Unsupported gemm_conv for {}D conv.", conv_ndim));
+ }
+ dispatch_1_2_3(conv_ndim, [&](auto ndim_constant) {
+ ConvParams<ndim_constant()> params(
+ in,
+ wt,
+ out,
+ strides,
+ padding,
+ kernel_dilation,
+ input_dilation,
+ groups,
+ flip);
+ gemm_grouped_conv_nd(encoder, in, wt, out, params, s);
+ });
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index 1aeeefa38..836385dfe 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -202,6 +202,25 @@ CublasGemm::~CublasGemm() {
 CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
 }
 
+void CublasGemm::set_out(
+ Dtype dtype,
+ bool transposed,
+ uint64_t rows,
+ uint64_t cols,
+ int64_t ld,
+ int32_t batch_count,
+ int64_t batch_stride) {
+ CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
+ out_desc_ = create_matrix_layout(
+ dtype_to_cublas_type(dtype),
+ rows,
+ cols,
+ transposed,
+ ld,
+ batch_count,
+ batch_stride);
+}
+
 void CublasGemm::run(
 cu::CommandEncoder& encoder,
 array& out,
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index e093351b6..1b06fb2f7 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -44,6 +44,17 @@ class CublasGemm {
 
 ~CublasGemm();
 
+ // The output's descriptor is inferred from the inputs by default; use this
+ // method when the output has an unusual layout.
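+ // For example, grouped convolution writes each group's GEMM result into an
+ // interleaved column slice of the full output, so its leading dimension and
+ // batch stride differ from the defaults.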
+ void set_out( + Dtype dtype, + bool transposed, + uint64_t rows, + uint64_t cols, + int64_t ld, + int32_t batch_count, + int64_t batch_stride); + void run( cu::CommandEncoder& encoder, array& out, diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index c635de9ad..78639da21 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -15,14 +15,6 @@ cuda_skip = { "TestOps.test_hadamard_grad_vmap", # Convolutions NYI "TestConv.test_1d_conv_with_2d", - "TestConv.test_conv_1d_groups_flipped", - "TestConv.test_conv_general_flip_grad", - "TestConv.test_torch_conv_2D", - "TestConv.test_torch_conv_depthwise", - "TestConv.test_torch_conv_general", - "TestConvTranspose.test_torch_conv_transpose_1D_grad", - "TestConvTranspose.test_torch_conv_transpose_2D_grad", - "TestConvTranspose.test_torch_conv_transpose_3D_grad", # FFTs NYI "TestFFT.test_fft", "TestFFT.test_fft_big_powers_of_two", From 512281781cc953231024d4529ad8b8ac22e5ee06 Mon Sep 17 00:00:00 2001 From: russellizadi Date: Wed, 20 Aug 2025 03:45:05 -0400 Subject: [PATCH 20/20] Remove state return from function example in compile documentation (#2518) --- docs/src/usage/compile.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/usage/compile.rst b/docs/src/usage/compile.rst index 7fe0ffd4f..ae01bb1f3 100644 --- a/docs/src/usage/compile.rst +++ b/docs/src/usage/compile.rst @@ -225,7 +225,7 @@ In some cases returning updated state can be pretty inconvenient. Hence, def fun(x, y): z = x + y state.append(z) - return mx.exp(z), state + return mx.exp(z) fun(mx.array(1.0), mx.array(2.0)) # Prints [array(3, dtype=float32)]