diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 49658dcd8..1a394e580 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -21,7 +21,8 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
@@ -47,6 +48,14 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
 
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
+  target_sources(
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
+else()
+  target_sources(
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
+endif()
+
 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
 
 # Embed kernel sources in binary for JIT compilation.
diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
new file mode 100644
index 000000000..39a8a5ddd
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
@@ -0,0 +1,73 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+
+namespace mlx::core::cu {
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides) {
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
+  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
+  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
+  auto concurrent = encoder.concurrent_context();
+  for (size_t i = 0; i < nbatch; ++i) {
+    run_impl(
+        encoder,
+        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
+        a.data<int8_t>() + a.itemsize() * a_it.loc,
+        b.data<int8_t>() + b.itemsize() * b_it.loc,
+        nullptr);
+    a_it.step();
+    b_it.step();
+  }
+}
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const array& c,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides,
+    const mlx::core::Strides& c_batch_strides,
+    float alpha,
+    float beta) {
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
+  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
+  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
+  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
+  auto concurrent = encoder.concurrent_context();
+  for (size_t i = 0; i < nbatch; ++i) {
+    run_impl(
+        encoder,
+        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
+        a.data<int8_t>() + a.itemsize() * a_it.loc,
+        b.data<int8_t>() + b.itemsize() * b_it.loc,
+        c.data<int8_t>() + c.itemsize() * c_it.loc,
+        alpha,
+        beta);
+    a_it.step();
+    b_it.step();
+    c_it.step();
+  }
+}
+
+} // namespace mlx::core::cu
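The file above is the pre-CUDA-12.9 fallback: run_batched loops over the batch on the host and issues one cublasLtMatmul per element, advancing per-input offsets with ContiguousIterator. As a minimal standalone sketch of that offset arithmetic (batch_offset is a hypothetical helper, not MLX API; shapes and strides are made up; it just mirrors the row-major index-to-offset computation that ContiguousIterator and the device-side elem_to_loc perform):

#include <cstdint>
#include <cstdio>
#include <vector>

// Offset (in elements) of batch element `index`, given a batch shape and
// per-array batch strides. A stride of 0 expresses a broadcast dimension.
int64_t batch_offset(
    int64_t index,
    const std::vector<int32_t>& shape,
    const std::vector<int64_t>& strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    loc += (index % shape[i]) * strides[i];
    index /= shape[i];
  }
  return loc;
}

int main() {
  // A (2, 3) batch where `a` is broadcast along the leading batch axis.
  std::vector<int32_t> batch_shape{2, 3};
  std::vector<int64_t> a_strides{0, 4096};
  std::vector<int64_t> b_strides{12288, 4096};
  for (int64_t i = 0; i < 6; ++i) {
    std::printf(
        "batch %lld: a_off=%lld b_off=%lld\n",
        static_cast<long long>(i),
        static_cast<long long>(batch_offset(i, batch_shape, a_strides)),
        static_cast<long long>(batch_offset(i, batch_shape, b_strides)));
  }
  return 0;
}

Note that the loop in run_batched only iterates over the outer batch dimensions (the iterators are built over batch_shape.size() - 1 dims), because the innermost batch dimension is already covered by the strided-batch settings of each cublasLtMatmul call.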
diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
new file mode 100644
index 000000000..4e72fdc64
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
@@ -0,0 +1,206 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core::cu {
+
+namespace cg = cooperative_groups;
+
+__global__ void set_mm_device_pointers(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ Shape batch_shape,
+    const __grid_constant__ Strides a_batch_strides,
+    const __grid_constant__ Strides b_batch_strides,
+    int64_t batch_stride,
+    int batch_ndim,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset] = elem_to_loc(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      batch_ndim);
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_addmm_device_pointers(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* c_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ Shape batch_shape,
+    const __grid_constant__ Strides a_batch_strides,
+    const __grid_constant__ Strides b_batch_strides,
+    const __grid_constant__ Strides c_batch_strides,
+    int64_t batch_stride,
+    int batch_ndim,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset, c_offset] = elem_to_loc(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      c_batch_strides.data(),
+      batch_ndim);
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
+  pointers[index + 3 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
+  auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+      desc,
+      CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
+      &batch_mode,
+      sizeof(batch_mode)));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+      desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
+}
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides) {
+  auto batch_count = out.size() / (M_ * N_);
+  set_pointer_mode(a_desc_, batch_count);
+  set_pointer_mode(b_desc_, batch_count);
+  set_pointer_mode(out_desc_, batch_count);
+
+  // Launch kernel to set device offsets
+  auto pointers = array(
+      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
+      {static_cast<int>(batch_count * 3)},
+      uint64);
+
+  encoder.add_temporary(pointers);
+  int block_size = 512;
+  encoder.set_output_array(pointers);
+
+  encoder.add_kernel_node(
+      cu::set_mm_device_pointers,
+      cuda::ceil_div(pointers.size(), block_size),
+      block_size,
+      pointers.data<int8_t*>(),
+      a.data<int8_t>(),
+      b.data<int8_t>(),
+      out.data<int8_t>(),
+      static_cast<int>(out.dtype().size()),
+      const_param(batch_shape),
+      const_param(a_batch_strides),
+      const_param(b_batch_strides),
+      static_cast<int64_t>(M_) * N_,
+      static_cast<int>(batch_shape.size()),
+      batch_count);
+
+  // Run matmul
+  encoder.set_input_array(pointers);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+
+  auto a_pointers = pointers.data<int8_t*>();
+  auto b_pointers = a_pointers + batch_count;
+  auto out_pointers = b_pointers + batch_count;
+  run_impl(
+      encoder,
+      reinterpret_cast<void*>(out_pointers),
+      reinterpret_cast<void*>(a_pointers),
+      reinterpret_cast<void*>(b_pointers),
+      nullptr);
+}
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const array& c,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides,
+    const mlx::core::Strides& c_batch_strides,
+    float alpha,
+    float beta) {
+  auto batch_count = out.size() / (M_ * N_);
+  set_pointer_mode(a_desc_, batch_count);
+  set_pointer_mode(b_desc_, batch_count);
+  set_pointer_mode(c_desc_, batch_count);
+  set_pointer_mode(out_desc_, batch_count);
+
+  // Launch kernel to set device offsets
+  auto pointers = array(
+      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
+      {static_cast<int>(batch_count * 4)},
+      uint64);
+
+  encoder.add_temporary(pointers);
+  int block_size = 512;
+  encoder.set_output_array(pointers);
+  encoder.add_kernel_node(
+      cu::set_addmm_device_pointers,
+      cuda::ceil_div(pointers.size(), block_size),
+      block_size,
+      pointers.data<int8_t*>(),
+      a.data<int8_t>(),
+      b.data<int8_t>(),
+      c.data<int8_t>(),
+      out.data<int8_t>(),
+      static_cast<int>(out.dtype().size()),
+      const_param(batch_shape),
+      const_param(a_batch_strides),
+      const_param(b_batch_strides),
+      const_param(c_batch_strides),
+      static_cast<int64_t>(M_) * N_,
+      static_cast<int>(batch_shape.size()),
+      batch_count);
+
+  // Run matmul
+  encoder.set_input_array(pointers);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto a_pointers = pointers.data<int8_t*>();
+  auto b_pointers = a_pointers + batch_count;
+  auto c_pointers = b_pointers + batch_count;
+  auto out_pointers = c_pointers + batch_count;
+  run_impl(
+      encoder,
+      reinterpret_cast<void*>(out_pointers),
+      reinterpret_cast<void*>(a_pointers),
+      reinterpret_cast<void*>(b_pointers),
+      reinterpret_cast<void*>(c_pointers),
+      alpha,
+      beta);
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
new file mode 100644
index 000000000..a0e936fd4
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -0,0 +1,282 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/utils.h"
+
+#include <fmt/format.h>
+
+namespace mlx::core::cu {
+
+struct CublasPreference {
+  CublasPreference(Device& device) {
+    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
+    // for Hopper+:
+    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
+    uint64_t MiB = 1024 * 1024;
+    uint64_t workspace_size =
+        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
+
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
+        pref_,
+        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+        &workspace_size,
+        sizeof(uint64_t)));
+  }
+
+  ~CublasPreference() {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
+  }
+
+  cublasLtMatmulPreference_t pref_{nullptr};
+};
+
+cublasLtMatmulPreference_t cublas_preference(Device& device) {
+  static CublasPreference pref(device);
+  return pref.pref_;
+}
+
+cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
+  switch (dtype) {
+    case float16:
+      return CUBLAS_COMPUTE_32F;
+    case bfloat16:
+      return CUBLAS_COMPUTE_32F;
+    case float32:
+      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                           : CUBLAS_COMPUTE_32F;
+    case float64:
+    case complex64:
+      return CUBLAS_COMPUTE_64F;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+  }
+}
+
+cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
+  switch (dtype) {
+    case float16:
+      return CUDA_R_16F;
+    case bfloat16:
+      return CUDA_R_16BF;
+    case float32:
+      return CUDA_R_32F;
+    case float64:
+      return CUDA_R_64F;
+    case complex64:
+      return CUDA_C_32F;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+  }
+}
+
+cublasLtMatrixLayout_t create_matrix_layout(
+    cudaDataType_t type,
+    uint64_t rows,
+    uint64_t cols,
+    bool transposed,
+    int64_t ld,
+    int32_t batch_count,
+    int64_t batch_stride) {
+  cublasLtMatrixLayout_t desc;
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
+  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
+  if (batch_count > 1) {
+    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+        &batch_count,
+        sizeof(int32_t)));
+    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+        &batch_stride,
+        sizeof(int64_t)));
+  }
+  return desc;
+}
+
+Matmul::Matmul(
+    Device& device,
+    Dtype dtype,
+    bool a_transposed,
+    uint64_t a_rows,
+    uint64_t a_cols,
+    int64_t lda,
+    bool b_transposed,
+    uint64_t b_rows,
+    uint64_t b_cols,
+    int64_t ldb,
+    int32_t batch_count,
+    int64_t a_batch_stride,
+    int64_t b_batch_stride)
+    : handle_(device.lt_handle()),
+      pref_(cublas_preference(device)),
+      M_(a_rows),
+      N_(b_cols) {
+  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
+
+  auto scale_type = dtype_to_cublas_type(dtype);
+  if (dtype == bfloat16 || dtype == float16) {
+    scale_type = CUDA_R_32F;
+  }
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
+      &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
+  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_POINTER_MODE,
+      &pointer_mode,
+      sizeof(int32_t)));
+  cublasOperation_t op = CUBLAS_OP_N;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_TRANSA,
+      &op,
+      sizeof(cublasOperation_t)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_TRANSB,
+      &op,
+      sizeof(cublasOperation_t)));
+
+  auto type = dtype_to_cublas_type(dtype);
+  a_desc_ = create_matrix_layout(
+      type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
+  b_desc_ = create_matrix_layout(
+      type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
+  out_desc_ = create_matrix_layout(
+      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
+}
+
+Matmul::Matmul(
+    Device& device,
+    Dtype dtype,
+    bool a_transposed,
+    uint64_t a_rows,
+    uint64_t a_cols,
+    int64_t lda,
+    bool b_transposed,
+    uint64_t b_rows,
+    uint64_t b_cols,
+    int64_t ldb,
+    int64_t ldc,
+    int32_t batch_count,
+    int64_t a_batch_stride,
+    int64_t b_batch_stride,
+    int64_t c_batch_stride)
+    : Matmul(
+          device,
+          dtype,
+          a_transposed,
+          a_rows,
+          a_cols,
+          lda,
+          b_transposed,
+          b_rows,
+          b_cols,
+          ldb,
+          batch_count,
+          a_batch_stride,
+          b_batch_stride) {
+  auto type = dtype_to_cublas_type(dtype);
+  c_desc_ = create_matrix_layout(
+      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
+}
+
+Matmul::~Matmul() {
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
+}
+
+void Matmul::run_impl(
+    cu::CommandEncoder& encoder,
+    void* out,
+    const void* a,
+    const void* b,
+    const void* c,
+    float alpha /* = 1 */,
+    float beta /* = 0 */) {
+  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
+    int ret = 0;
+    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
+        handle_,
+        matmul_desc_,
+        a_desc_,
+        b_desc_,
+        out_desc_, // TODO should that be c_desc if it's set?
+        out_desc_,
+        pref_,
+        1,
+        &heuristic_,
+        &ret));
+    if (ret == 0) {
+      throw std::runtime_error("Can not find algorithm for matmul.");
+    }
+  }
+
+  void* workspace_ptr = nullptr;
+  if (heuristic_.workspaceSize > 0) {
+    array workspace(
+        allocator::malloc(heuristic_.workspaceSize),
+        {static_cast<int>(heuristic_.workspaceSize)},
+        int8);
+    encoder.add_temporary(workspace);
+    workspace_ptr = workspace.data<void>();
+  }
+
+  auto capture = encoder.capture_context();
+  CHECK_CUBLAS_ERROR(cublasLtMatmul(
+      handle_,
+      matmul_desc_,
+      &alpha,
+      a,
+      a_desc_,
+      b,
+      b_desc_,
+      &beta,
+      c ? c : out,
+      c ? c_desc_ : out_desc_,
+      out,
+      out_desc_,
+      &heuristic_.algo,
+      workspace_ptr,
+      heuristic_.workspaceSize,
+      encoder.stream()));
+}
+
+void Matmul::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const std::optional<array>& c /* = std::nullopt */,
+    float alpha /* = 1 */,
+    float beta /* = 0 */) {
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  if (c) {
+    encoder.set_input_array(*c);
+  }
+  encoder.set_output_array(out);
+
+  run_impl(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      c ? c->data<void>() : nullptr,
+      alpha,
+      beta);
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
new file mode 100644
index 000000000..eccee8580
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -0,0 +1,100 @@
+// Copyright © 2025 Apple Inc.
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device.h"
+
+#include <cublasLt.h>
+#include <optional>
+
+namespace mlx::core::cu {
+class Matmul {
+ public:
+  Matmul(
+      Device& device,
+      Dtype dtype,
+      bool a_transposed,
+      uint64_t a_rows,
+      uint64_t a_cols,
+      int64_t lda,
+      bool b_transposed,
+      uint64_t b_rows,
+      uint64_t b_cols,
+      int64_t ldb,
+      int32_t batch_count,
+      int64_t a_batch_stride,
+      int64_t b_batch_stride);
+
+  Matmul(
+      Device& device,
+      Dtype dtype,
+      bool a_transposed,
+      uint64_t a_rows,
+      uint64_t a_cols,
+      int64_t lda,
+      bool b_transposed,
+      uint64_t b_rows,
+      uint64_t b_cols,
+      int64_t ldb,
+      int64_t ldc,
+      int32_t batch_count,
+      int64_t a_batch_stride,
+      int64_t b_batch_stride,
+      int64_t c_batch_stride);
+
+  ~Matmul();
+
+  void run(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const std::optional<array>& c = std::nullopt,
+      float alpha = 1,
+      float beta = 0);
+
+  void run_batched(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const mlx::core::Shape& batch_shape,
+      const mlx::core::Strides& a_batch_strides,
+      const mlx::core::Strides& b_batch_strides);
+
+  void run_batched(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const array& c,
+      const mlx::core::Shape& batch_shape,
+      const mlx::core::Strides& a_batch_strides,
+      const mlx::core::Strides& b_batch_strides,
+      const mlx::core::Strides& c_batch_strides,
+      float alpha,
+      float beta);
+
+ private:
+  void run_impl(
+      cu::CommandEncoder& encoder,
+      void* out,
+      const void* a,
+      const void* b,
+      const void* c,
+      float alpha = 1,
+      float beta = 0);
+
+  uint64_t M_;
+  uint64_t N_;
+  cublasLtMatmulPreference_t pref_{nullptr};
+  cublasLtHandle_t handle_{nullptr};
+  cublasLtMatmulDesc_t matmul_desc_{nullptr};
+  cublasLtMatrixLayout_t a_desc_{nullptr};
+  cublasLtMatrixLayout_t b_desc_{nullptr};
+  cublasLtMatrixLayout_t c_desc_{nullptr};
+  cublasLtMatrixLayout_t out_desc_{nullptr};
+  cublasLtMatmulHeuristicResult_t heuristic_;
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu
similarity index 99%
rename from mlx/backend/cuda/gemv.cu
rename to mlx/backend/cuda/gemms/gemv.cu
index fe0f7a327..b62d6e6b8 100644
--- a/mlx/backend/cuda/gemv.cu
+++ b/mlx/backend/cuda/gemms/gemv.cu
@@ -1,6 +1,6 @@
 // Copyright © 2025 Apple Inc.
-#include "mlx/backend/cuda/gemv.h" +#include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/dtype_utils.h" diff --git a/mlx/backend/cuda/gemv.h b/mlx/backend/cuda/gemms/gemv.h similarity index 100% rename from mlx/backend/cuda/gemv.h rename to mlx/backend/cuda/gemms/gemv.h diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp index efddf2506..283aaaf2e 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cpp @@ -2,279 +2,15 @@ #include "mlx/backend/common/matmul.h" #include "mlx/backend/cuda/device.h" -#include "mlx/backend/cuda/gemv.h" +#include "mlx/backend/cuda/gemms/cublas_gemm.h" +#include "mlx/backend/cuda/gemms/gemv.h" #include "mlx/backend/gpu/copy.h" -#include "mlx/dtype_utils.h" #include "mlx/primitives.h" -#include "mlx/utils.h" -#include #include - #include namespace mlx::core { - -namespace cu { - -struct CublasPreference { - CublasPreference(Device& device) { - // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB - // for Hopper+: - // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace - uint64_t MiB = 1024 * 1024; - uint64_t workspace_size = - device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB; - - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_)); - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute( - pref_, - CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace_size, - sizeof(uint64_t))); - } - - ~CublasPreference() { - CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_)); - } - - cublasLtMatmulPreference_t pref_{nullptr}; -}; - -cublasLtMatmulPreference_t cublas_preference(Device& device) { - static CublasPreference pref(device); - return pref.pref_; -} - -class MatMul { - public: - MatMul( - Device& device, - Dtype dtype, - bool a_transposed, - uint64_t a_rows, - uint64_t a_cols, - int64_t lda, - bool b_transposed, - uint64_t b_rows, - uint64_t b_cols, - int64_t ldb, - int32_t batch_count, - int64_t a_batch_stride, - int64_t b_batch_stride) - : handle_(device.lt_handle()), pref_(cublas_preference(device)) { - heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED; - - auto scale_type = dtype_to_cuda_type(dtype); - if (dtype == bfloat16 || dtype == float16) { - scale_type = CUDA_R_32F; - } - CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate( - &matmul_desc_, dtype_to_compute_type(dtype), scale_type)); - int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_POINTER_MODE, - &pointer_mode, - sizeof(int32_t))); - cublasOperation_t op = CUBLAS_OP_N; - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &op, - sizeof(cublasOperation_t))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - matmul_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, - &op, - sizeof(cublasOperation_t))); - - auto type = dtype_to_cuda_type(dtype); - a_desc_ = create_matrix_layout( - type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride); - b_desc_ = create_matrix_layout( - type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride); - out_desc_ = create_matrix_layout( - type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols); - } - - MatMul( - Device& device, - Dtype dtype, - bool a_transposed, - uint64_t a_rows, - uint64_t a_cols, - int64_t lda, - bool b_transposed, - uint64_t b_rows, - uint64_t b_cols, - int64_t ldb, - int64_t ldc, - int32_t batch_count, - int64_t a_batch_stride, - int64_t b_batch_stride, - 
-      int64_t c_batch_stride)
-      : MatMul(
-            device,
-            dtype,
-            a_transposed,
-            a_rows,
-            a_cols,
-            lda,
-            b_transposed,
-            b_rows,
-            b_cols,
-            ldb,
-            batch_count,
-            a_batch_stride,
-            b_batch_stride) {
-    auto type = dtype_to_cuda_type(dtype);
-    c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
-  }
-
-  ~MatMul() {
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
-  }
-
-  void run(
-      cu::CommandEncoder& encoder,
-      void* out,
-      void* a,
-      void* b,
-      void* c = nullptr,
-      float alpha = 1,
-      float beta = 0) {
-    if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
-      int ret = 0;
-      CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
-          handle_,
-          matmul_desc_,
-          a_desc_,
-          b_desc_,
-          out_desc_,
-          out_desc_,
-          pref_,
-          1,
-          &heuristic_,
-          &ret));
-      if (ret == 0) {
-        throw std::runtime_error("Can not find algorithm for matmul.");
-      }
-    }
-
-    void* workspace_ptr = nullptr;
-    if (heuristic_.workspaceSize > 0) {
-      array workspace(
-          allocator::malloc(heuristic_.workspaceSize),
-          {static_cast<int>(heuristic_.workspaceSize)},
-          int8);
-      encoder.add_temporary(workspace);
-      workspace_ptr = workspace.data<void>();
-    }
-
-    auto capture = encoder.capture_context();
-    CHECK_CUBLAS_ERROR(cublasLtMatmul(
-        handle_,
-        matmul_desc_,
-        &alpha,
-        a,
-        a_desc_,
-        b,
-        b_desc_,
-        &beta,
-        c ? c : out,
-        c ? c_desc_ : out_desc_,
-        out,
-        out_desc_,
-        &heuristic_.algo,
-        workspace_ptr,
-        heuristic_.workspaceSize,
-        encoder.stream()));
-  }
-
- private:
-  cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
-    switch (dtype) {
-      case float16:
-        return CUBLAS_COMPUTE_32F;
-      case bfloat16:
-        return CUBLAS_COMPUTE_32F;
-      case float32:
-        return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
-                                             : CUBLAS_COMPUTE_32F;
-      case float64:
-      case complex64:
-        return CUBLAS_COMPUTE_64F;
-      default:
-        throw std::runtime_error(fmt::format(
-            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
-    }
-  }
-
-  cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
-    switch (dtype) {
-      case float16:
-        return CUDA_R_16F;
-      case bfloat16:
-        return CUDA_R_16BF;
-      case float32:
-        return CUDA_R_32F;
-      case float64:
-        return CUDA_R_64F;
-      case complex64:
-        return CUDA_C_32F;
-      default:
-        throw std::runtime_error(fmt::format(
-            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
-    }
-  }
-
-  cublasLtMatrixLayout_t create_matrix_layout(
-      cudaDataType_t type,
-      uint64_t rows,
-      uint64_t cols,
-      bool transposed,
-      int64_t ld,
-      int32_t batch_count,
-      int64_t batch_stride) {
-    cublasLtMatrixLayout_t desc;
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
-    cublasLtOrder_t order =
-        transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-        desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
-    if (batch_count > 1) {
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
-          &batch_count,
-          sizeof(int32_t)));
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
-          &batch_stride,
-          sizeof(int64_t)));
-    }
-    return desc;
-  }
-
-  cublasLtMatmulPreference_t pref_{nullptr};
-  cublasLtHandle_t handle_{nullptr};
-  cublasLtMatmulDesc_t matmul_desc_{nullptr};
-  cublasLtMatrixLayout_t a_desc_{nullptr};
-  cublasLtMatrixLayout_t b_desc_{nullptr};
-  cublasLtMatrixLayout_t c_desc_{nullptr};
-  cublasLtMatrixLayout_t out_desc_{nullptr};
-  cublasLtMatmulHeuristicResult_t heuristic_;
-};
-
-} // namespace cu
-
 namespace {
@@ -361,8 +97,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
-
-  cu::MatMul matmul(
+  cu::Matmul matmul(
       cu::device(s.device),
       a.dtype(),
       a_transposed,
@@ -377,27 +112,13 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
      b_batch_strides.back());
 
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
-    matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
+  if ((batch_count / batch_shape.back()) == 1) {
+    matmul.run(encoder, out, a, b);
     return;
   }
 
-  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  auto concurrent = encoder.concurrent_context();
-  for (size_t i = 0; i < nbatch; ++i) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc);
-    a_it.step();
-    b_it.step();
-  }
+  matmul.run_batched(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
 }
 
 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -465,7 +186,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 
   /////////////////////////////////////////////////////////////////////////////
   // Invoke cublasLt
-  cu::MatMul matmul(
+  cu::Matmul matmul(
       cu::device(s.device),
       a.dtype(),
       a_transposed,
@@ -482,41 +203,22 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       b_batch_strides.back(),
       c_batch_strides.back());
 
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_input_array(c);
-  encoder.set_output_array(out);
-
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>(),
-        a.data<int8_t>(),
-        b.data<int8_t>(),
-        c.data<int8_t>(),
-        alpha_,
-        beta_);
+  if ((batch_count / batch_shape.back()) == 1) {
+    matmul.run(encoder, out, a, b, c, alpha_, beta_);
     return;
   }
-
-  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  auto concurrent = encoder.concurrent_context();
-  for (size_t i = 0; i < nbatch; ++i) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc,
-        c.data<int8_t>() + c.itemsize() * c_it.loc,
-        alpha_,
-        beta_);
-    a_it.step();
-    b_it.step();
-    c_it.step();
-  }
+  matmul.run_batched(
+      encoder,
+      out,
+      a,
+      b,
+      c,
+      batch_shape,
+      a_batch_strides,
+      b_batch_strides,
+      c_batch_strides,
+      alpha_,
+      beta_);
 }
 
 } // namespace mlx::core
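On CUDA 12.9 and newer, the batched path avoids the host-side loop entirely: it switches every matrix layout to CUBLASLT_BATCH_MODE_POINTER_ARRAY and fills one temporary buffer of 3 * batch_count device pointers (4 * batch_count when a C operand is present), so a single cublasLtMatmul covers the whole batch. Below is a rough host-side sketch of how that pointer table is laid out, assuming contiguous batches for simplicity (the real set_mm_device_pointers kernel uses elem_to_loc, so broadcast or strided batch dimensions also work); all names, sizes, and addresses here are illustrative only and not MLX code.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Made-up problem sizes; the real values come from the Matmul descriptors.
  const int batch_count = 4;
  const int64_t M = 8, N = 8, K = 8;
  const int item_size = 4; // float32

  // Stand-in base addresses of the packed A, B and out allocations.
  // They are never dereferenced here.
  int8_t* a_base = reinterpret_cast<int8_t*>(0x1000);
  int8_t* b_base = reinterpret_cast<int8_t*>(0x2000);
  int8_t* o_base = reinterpret_cast<int8_t*>(0x3000);

  // One table of 3 * batch_count entries, filled the way the device kernel
  // fills it: A pointers first, then B pointers, then output pointers.
  std::vector<int8_t*> table(3 * batch_count);
  for (int i = 0; i < batch_count; ++i) {
    table[i] = a_base + item_size * i * (M * K);
    table[i + batch_count] = b_base + item_size * i * (K * N);
    table[i + 2 * batch_count] = o_base + item_size * i * (M * N);
  }

  // In pointer-array batch mode, the matmul is handed the three sub-table
  // bases instead of the matrices themselves.
  int8_t** a_pointers = table.data();
  int8_t** b_pointers = a_pointers + batch_count;
  int8_t** out_pointers = b_pointers + batch_count;
  std::printf(
      "batch 2: a=%p b=%p out=%p\n",
      static_cast<void*>(a_pointers[2]),
      static_cast<void*>(b_pointers[2]),
      static_cast<void*>(out_pointers[2]));
  return 0;
}

The CMake gate at the top of the diff picks between the two implementations at build time; the pointer-array batch mode used here is presumably what requires the newer (12.9+) cuBLASLt, while older toolkits fall back to the per-batch loop in cublas_batched_gemm_12_0.cpp.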