diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 59db14535..084e449b0 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -53,10 +53,10 @@ target_sources(
 
 if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
 else()
   target_sources(
-    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
 endif()
 
 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
index 61f12ba1d..1aeeefa38 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -7,10 +7,12 @@
 
 #include 
 
-namespace mlx::core::cu {
+namespace mlx::core {
+
+namespace {
 
 struct CublasPreference {
-  CublasPreference(Device& device) {
+  CublasPreference(cu::Device& device) {
     // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
     // for Hopper+:
     // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
@@ -33,7 +35,7 @@ struct CublasPreference {
   cublasLtMatmulPreference_t pref_{nullptr};
 };
 
-cublasLtMatmulPreference_t cublas_preference(Device& device) {
+cublasLtMatmulPreference_t cublas_preference(cu::Device& device) {
   static CublasPreference pref(device);
   return pref.pref_;
 }
@@ -52,7 +54,7 @@ cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
       return CUBLAS_COMPUTE_64F;
     default:
       throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
   }
 }
 
@@ -70,7 +72,7 @@ cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
      return CUDA_C_32F;
     default:
       throw std::runtime_error(fmt::format(
-          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+          "Unsupported dtype in CublasGemm: {}.", dtype_to_string(dtype)));
   }
 }
 
@@ -102,8 +104,10 @@ cublasLtMatrixLayout_t create_matrix_layout(
   return desc;
 }
 
-Matmul::Matmul(
-    Device& device,
+} // namespace
+
+CublasGemm::CublasGemm(
+    cu::Device& device,
     Dtype dtype,
     bool a_transposed,
     uint64_t a_rows,
@@ -155,8 +159,8 @@ Matmul::Matmul(
       type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
 }
 
-Matmul::Matmul(
-    Device& device,
+CublasGemm::CublasGemm(
+    cu::Device& device,
     Dtype dtype,
     bool a_transposed,
     uint64_t a_rows,
@@ -171,7 +175,7 @@ Matmul::Matmul(
     int64_t a_batch_stride,
     int64_t b_batch_stride,
     int64_t c_batch_stride)
-    : Matmul(
+    : CublasGemm(
           device,
           dtype,
           a_transposed,
@@ -190,7 +194,7 @@ Matmul::Matmul(
       type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
 }
 
-Matmul::~Matmul() {
+CublasGemm::~CublasGemm() {
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
@@ -198,7 +202,73 @@ Matmul::~Matmul() {
   CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
 }
 
-void Matmul::run_impl(
+void CublasGemm::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
+  int batch_count = out.size() / (M_ * N_);
+  if (batch_count / batch_shape.back() > 1) {
+    run_batched(
+        encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+    return;
+  }
+
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+
+  execute(encoder, out.data(), a.data(), b.data(), nullptr);
+}
+
+void CublasGemm::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const array& c,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
+    float alpha,
+    float beta) {
+  int batch_count = out.size() / (M_ * N_);
+  if (batch_count / batch_shape.back() > 1) {
+    run_batched(
+        encoder,
+        out,
+        a,
+        b,
+        c,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        c_batch_strides,
+        alpha,
+        beta);
+    return;
+  }
+
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  execute(
+      encoder,
+      out.data(),
+      a.data(),
+      b.data(),
+      c.data(),
+      alpha,
+      beta);
+}
+
+void CublasGemm::execute(
     cu::CommandEncoder& encoder,
     void* out,
     const void* a,
@@ -256,29 +326,4 @@ void Matmul::run_impl(
       encoder.stream()));
 }
 
-void Matmul::run(
-    cu::CommandEncoder& encoder,
-    array& out,
-    const array& a,
-    const array& b,
-    const std::optional& c /* = std::nullopt */,
-    float alpha /* = 1 */,
-    float beta /* = 0 */) {
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  if (c) {
-    encoder.set_input_array(*c);
-  }
-  encoder.set_output_array(out);
-
-  run_impl(
-      encoder,
-      out.data(),
-      a.data(),
-      b.data(),
-      c ? c->data() : nullptr,
-      alpha,
-      beta);
-}
-
-} // namespace mlx::core::cu
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
index eccee8580..e093351b6 100644
--- a/mlx/backend/cuda/gemms/cublas_gemm.h
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -5,13 +5,13 @@
 #include "mlx/backend/cuda/device.h"
 
 #include 
-#include 
 
-namespace mlx::core::cu {
-class Matmul {
+namespace mlx::core {
+
+class CublasGemm {
  public:
-  Matmul(
-      Device& device,
+  CublasGemm(
+      cu::Device& device,
       Dtype dtype,
       bool a_transposed,
       uint64_t a_rows,
@@ -25,8 +25,8 @@ class Matmul {
       int64_t a_batch_stride,
       int64_t b_batch_stride);
 
-  Matmul(
-      Device& device,
+  CublasGemm(
+      cu::Device& device,
       Dtype dtype,
       bool a_transposed,
       uint64_t a_rows,
@@ -42,25 +42,39 @@ class Matmul {
       int64_t b_batch_stride,
       int64_t c_batch_stride);
 
-  ~Matmul();
+  ~CublasGemm();
 
   void run(
       cu::CommandEncoder& encoder,
       array& out,
       const array& a,
       const array& b,
-      const std::optional& c = std::nullopt,
-      float alpha = 1,
-      float beta = 0);
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides);
 
+  void run(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const array& c,
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides,
+      const Strides& c_batch_strides,
+      float alpha,
+      float beta);
+
+ private:
   void run_batched(
       cu::CommandEncoder& encoder,
       array& out,
       const array& a,
       const array& b,
-      const mlx::core::Shape& batch_shape,
-      const mlx::core::Strides& a_batch_strides,
-      const mlx::core::Strides& b_batch_strides);
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides);
 
   void run_batched(
       cu::CommandEncoder& encoder,
@@ -68,15 +82,14 @@ class Matmul {
       const array& a,
       const array& b,
       const array& c,
-      const mlx::core::Shape& batch_shape,
-      const mlx::core::Strides& a_batch_strides,
-      const mlx::core::Strides& b_batch_strides,
-      const mlx::core::Strides& c_batch_strides,
+      const Shape& batch_shape,
+      const Strides& a_batch_strides,
+      const Strides& b_batch_strides,
+      const Strides& c_batch_strides,
       float alpha,
       float beta);
 
- private:
-  void run_impl(
+  void execute(
       cu::CommandEncoder& encoder,
       void* out,
       const void* a,
@@ -97,4 +110,4 @@ class Matmul {
   cublasLtMatmulHeuristicResult_t heuristic_;
 };
 
-} // namespace mlx::core::cu
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
similarity index 80%
rename from mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
rename to mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
index 39a8a5ddd..56c731587 100644
--- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
+++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp
@@ -4,16 +4,16 @@
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 
-namespace mlx::core::cu {
+namespace mlx::core {
 
-void Matmul::run_batched(
+void CublasGemm::run_batched(
     cu::CommandEncoder& encoder,
     array& out,
     const array& a,
     const array& b,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides) {
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
   encoder.set_input_array(a);
   encoder.set_input_array(b);
   encoder.set_output_array(out);
@@ -22,7 +22,7 @@ void Matmul::run_batched(
   ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
   auto concurrent = encoder.concurrent_context();
   for (size_t i = 0; i < nbatch; ++i) {
-    run_impl(
+    execute(
         encoder,
         out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data() + a.itemsize() * a_it.loc,
@@ -33,16 +33,16 @@ void Matmul::run_batched(
   }
 }
 
-void Matmul::run_batched(
+void CublasGemm::run_batched(
     cu::CommandEncoder& encoder,
     array& out,
     const array& a,
     const array& b,
     const array& c,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides,
-    const mlx::core::Strides& c_batch_strides,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
     float alpha,
     float beta) {
   encoder.set_input_array(a);
@@ -56,7 +56,7 @@ void Matmul::run_batched(
   ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
   auto concurrent = encoder.concurrent_context();
   for (size_t i = 0; i < nbatch; ++i) {
-    run_impl(
+    execute(
         encoder,
         out.data() + out.itemsize() * i * batch_shape.back() * M_ * N_,
         a.data() + a.itemsize() * a_it.loc,
@@ -70,4 +70,4 @@ void Matmul::run_batched(
   }
 }
 
-} // namespace mlx::core::cu
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
similarity index 95%
rename from mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
rename to mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
index da7163b42..570b79463 100644
--- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
+++ b/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu
@@ -6,7 +6,9 @@
 
 #include 
 
-namespace mlx::core::cu {
+namespace mlx::core {
+
+namespace cu {
 
 namespace cg = cooperative_groups;
 
@@ -128,6 +130,10 @@ __global__ void set_addmm_device_pointers_g(
       out_start + item_size * index * batch_stride;
 }
 
+} // namespace cu
+
+namespace {
+
 void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
   auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
   CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
@@ -139,14 +145,16 @@ void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
       desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
 }
 
-void Matmul::run_batched(
+} // namespace
+
+void CublasGemm::run_batched(
     cu::CommandEncoder& encoder,
     array& out,
     const array& a,
     const array& b,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides) {
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides) {
   int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
@@ -213,7 +221,7 @@ void Matmul::run_batched(
   auto a_pointers = pointers.data();
   auto b_pointers = a_pointers + batch_count;
   auto out_pointers = b_pointers + batch_count;
-  run_impl(
+  execute(
       encoder,
       reinterpret_cast(out_pointers),
       reinterpret_cast(a_pointers),
@@ -221,16 +229,16 @@ void Matmul::run_batched(
       nullptr);
 }
 
-void Matmul::run_batched(
+void CublasGemm::run_batched(
     cu::CommandEncoder& encoder,
     array& out,
     const array& a,
     const array& b,
     const array& c,
-    const mlx::core::Shape& batch_shape,
-    const mlx::core::Strides& a_batch_strides,
-    const mlx::core::Strides& b_batch_strides,
-    const mlx::core::Strides& c_batch_strides,
+    const Shape& batch_shape,
+    const Strides& a_batch_strides,
+    const Strides& b_batch_strides,
+    const Strides& c_batch_strides,
     float alpha,
     float beta) {
   int batch_count = out.size() / (M_ * N_);
@@ -306,7 +314,7 @@ void Matmul::run_batched(
   auto b_pointers = a_pointers + batch_count;
   auto c_pointers = b_pointers + batch_count;
   auto out_pointers = c_pointers + batch_count;
-  run_impl(
+  execute(
       encoder,
       reinterpret_cast(out_pointers),
       reinterpret_cast(a_pointers),
@@ -316,4 +324,4 @@ void Matmul::run_batched(
       beta);
 }
 
-} // namespace mlx::core::cu
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 283aaaf2e..b11fae538 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -97,7 +97,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
 
-  cu::Matmul matmul(
+  CublasGemm gemm(
       cu::device(s.device),
       a.dtype(),
       a_transposed,
@@ -111,14 +111,7 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) {
       batch_shape.back(),
       a_batch_strides.back(),
       b_batch_strides.back());
-
-  if ((batch_count / batch_shape.back()) == 1) {
-    matmul.run(encoder, out, a, b);
-    return;
-  }
-
-  matmul.run_batched(
-      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+  gemm.run(encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
 }
 
 void AddMM::eval_gpu(const std::vector& inputs, array& out) {
@@ -186,7 +179,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) {
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
 
-  cu::Matmul matmul(
+  CublasGemm gemm(
      cu::device(s.device),
       a.dtype(),
       a_transposed,
@@ -202,12 +195,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back(),
       c_batch_strides.back());
-
-  if ((batch_count / batch_shape.back()) == 1) {
-    matmul.run(encoder, out, a, b, c, alpha_, beta_);
-    return;
-  }
-  matmul.run_batched(
+  gemm.run(
       encoder,
       out,
       a,