From c9a91805841ffd1032261a4dda8e298b642e0eec Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Fri, 20 Jun 2025 14:50:57 -0700
Subject: [PATCH] Cuda perf tuning (#2307)

* perf tuning

* fix adding input arrays in matmul / sort

* format

* fix
---
 mlx/backend/cuda/allocator.cpp    | 12 +++++++-
 mlx/backend/cuda/copy.cu          |  1 -
 mlx/backend/cuda/device/utils.cuh | 20 ++++++-------
 mlx/backend/cuda/matmul.cpp       | 49 +++++++++++++++++++++++++------
 mlx/backend/cuda/sort.cu          |  5 ++--
 5 files changed, 63 insertions(+), 24 deletions(-)

diff --git a/mlx/backend/cuda/allocator.cpp b/mlx/backend/cuda/allocator.cpp
index 1d17d7df5a..6cc7145b57 100644
--- a/mlx/backend/cuda/allocator.cpp
+++ b/mlx/backend/cuda/allocator.cpp
@@ -3,6 +3,7 @@
 #include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/utils.h"
 #include "mlx/backend/cuda/worker.h"
+#include "mlx/utils.h"

 #include
 #include
@@ -14,9 +15,11 @@ namespace mlx::core {

 namespace cu {

+constexpr int page_size = 16384;
+
 CudaAllocator::CudaAllocator()
     : buffer_cache_(
-          getpagesize(),
+          page_size,
           [](CudaBuffer* buf) { return buf->size; },
           [this](CudaBuffer* buf) {
             cuda_free(buf->data);
@@ -31,7 +34,14 @@ CudaAllocator::CudaAllocator()

 Buffer CudaAllocator::malloc(size_t size) {
   // Find available buffer from cache.
+  auto orig_size = size;
   std::unique_lock lock(mutex_);
+  if (size < page_size) {
+    size = next_power_of_2(size);
+  } else {
+    size = page_size * ((size + page_size - 1) / page_size);
+  }
+
   CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
   if (!buf) {
     // If we have a lot of memory pressure or are over the maximum cache size,
diff --git a/mlx/backend/cuda/copy.cu b/mlx/backend/cuda/copy.cu
index 817860d0ac..3218067209 100644
--- a/mlx/backend/cuda/copy.cu
+++ b/mlx/backend/cuda/copy.cu
@@ -24,7 +24,6 @@ void copy_gpu_inplace(
   auto& encoder = cu::get_command_encoder(s);
   encoder.set_input_array(in);
   encoder.set_output_array(out);
-
   if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
     copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
     return;
diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh
index 54d5519921..6e8abdd7c0 100644
--- a/mlx/backend/cuda/device/utils.cuh
+++ b/mlx/backend/cuda/device/utils.cuh
@@ -155,8 +155,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);
 }
@@ -175,9 +175,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_nd(
 #pragma unroll
   for (int i = NDIM - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);
 }
@@ -206,8 +206,8 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT> elem_to_loc_4d(
   IdxT b_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc);
 }
@@ -226,9 +226,9 @@ inline __host__ __device__ cuda::std::tuple<IdxT, IdxT, IdxT> elem_to_loc_4d(
   IdxT c_loc = 0;
   for (int i = ndim - 1; i >= 0; --i) {
     int dim_idx = elem % shape[i];
-    a_loc += dim_idx * a_strides[i];
-    b_loc += dim_idx * b_strides[i];
-    c_loc += dim_idx * c_strides[i];
+    a_loc += dim_idx * IdxT(a_strides[i]);
+    b_loc += dim_idx * IdxT(b_strides[i]);
+    c_loc += dim_idx * IdxT(c_strides[i]);
     elem /= shape[i];
   }
   return cuda::std::make_tuple(a_loc, b_loc, c_loc);
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 5a5e6182e8..c32cecc031 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -162,11 +162,15 @@ class MatMul {
       }
     }

-    array workspace(
-        allocator::malloc(heuristic_.workspaceSize),
-        {static_cast(heuristic_.workspaceSize)},
-        int8);
-    encoder.add_temporary(workspace);
+    void* workspace_ptr = nullptr;
+    if (heuristic_.workspaceSize > 0) {
+      array workspace(
+          allocator::malloc(heuristic_.workspaceSize),
+          {static_cast(heuristic_.workspaceSize)},
+          int8);
+      encoder.add_temporary(workspace);
+      workspace_ptr = workspace.data();
+    }

     encoder.launch_kernel([&](cudaStream_t stream) {
       CHECK_CUBLAS_ERROR(cublasLtMatmul(
@@ -183,8 +187,8 @@ class MatMul {
           out,
           out_desc_,
           &heuristic_.algo,
-          workspace.data(),
-          workspace.nbytes(),
+          workspace_ptr,
+          heuristic_.workspaceSize,
           stream));
     });
   }
@@ -358,9 +362,18 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back());

+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(encoder, out.data(), a.data(), b.data());
+    return;
+  }
+
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data() + out.itemsize() * i * batch_shape.back() * M * N,
@@ -444,10 +457,28 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       b_batch_strides.back(),
       c_batch_strides.back());

+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto nbatch = batch_count / batch_shape.back();
+  if (nbatch == 1) {
+    matmul.run(
+        encoder,
+        out.data(),
+        a.data(),
+        b.data(),
+        c.data(),
+        alpha_,
+        beta_);
+    return;
+  }
+
   ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  for (size_t i = 0; i < batch_count / batch_shape.back(); ++i) {
+  for (size_t i = 0; i < nbatch; ++i) {
     matmul.run(
         encoder,
         out.data() + out.itemsize() * i * batch_shape.back() * M * N,
diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu
index 154ca5f32c..5cbffc0f43 100644
--- a/mlx/backend/cuda/sort.cu
+++ b/mlx/backend/cuda/sort.cu
@@ -79,9 +79,6 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
 void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   array out = out_;
   auto& encoder = cu::get_command_encoder(s);
-  encoder.set_input_array(in);
-  encoder.set_output_array(out);
-
   if (axis < 0) {
     axis += in.ndim();
   }
@@ -106,6 +103,8 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
         in.flags());
   }

+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
     MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
       if constexpr (!std::is_same_v) {