From e1d40be0b6ec496bf6990a130124cf7c04188028 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 22 Jul 2025 13:27:09 -0700
Subject: [PATCH] addmm as well

---
 mlx/backend/cuda/CMakeLists.txt               |  13 +-
 .../cuda/gemms/cublas_batched_gemm_12_0.cpp   |  73 +++
 .../cuda/gemms/cublas_batched_gemm_12_9.cu    | 206 +++++++
 mlx/backend/cuda/gemms/cublas_gemm.cpp        | 290 +++++++++
 mlx/backend/cuda/gemms/cublas_gemm.h          | 104 ++++
 mlx/backend/cuda/{ => gemms}/gemv.cu          |   2 +-
 mlx/backend/cuda/{ => gemms}/gemv.h           |   0
 mlx/backend/cuda/matmul.cpp                   | 224 +++++++
 mlx/backend/cuda/matmul.cu                    | 574 ------------------
 9 files changed, 909 insertions(+), 577 deletions(-)
 create mode 100644 mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
 create mode 100644 mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
 create mode 100644 mlx/backend/cuda/gemms/cublas_gemm.cpp
 create mode 100644 mlx/backend/cuda/gemms/cublas_gemm.h
 rename mlx/backend/cuda/{ => gemms}/gemv.cu (99%)
 rename mlx/backend/cuda/{ => gemms}/gemv.h (100%)
 create mode 100644 mlx/backend/cuda/matmul.cpp
 delete mode 100644 mlx/backend/cuda/matmul.cu

diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt
index 0e8f64e20..1a394e580 100644
--- a/mlx/backend/cuda/CMakeLists.txt
+++ b/mlx/backend/cuda/CMakeLists.txt
@@ -21,11 +21,12 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
-          ${CMAKE_CURRENT_SOURCE_DIR}/gemv.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
-          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
+          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
           ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
@@ -47,6 +48,14 @@ target_sources(
           ${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
           ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
 
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
+  target_sources(
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_9.cu)
+else()
+  target_sources(
+    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_batched_gemm_12_0.cpp)
+endif()
+
 target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
 
 # Embed kernel sources in binary for JIT compilation.
diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
new file mode 100644
index 000000000..39a8a5ddd
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_0.cpp
@@ -0,0 +1,73 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+
+namespace mlx::core::cu {
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides) {
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
+  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
+  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
+  auto concurrent = encoder.concurrent_context();
+  for (size_t i = 0; i < nbatch; ++i) {
+    run_impl(
+        encoder,
+        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
+        a.data<int8_t>() + a.itemsize() * a_it.loc,
+        b.data<int8_t>() + b.itemsize() * b_it.loc,
+        nullptr);
+    a_it.step();
+    b_it.step();
+  }
+}
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const array& c,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides,
+    const mlx::core::Strides& c_batch_strides,
+    float alpha,
+    float beta) {
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto nbatch = out.size() / (M_ * N_ * batch_shape.back());
+  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
+  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
+  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
+  auto concurrent = encoder.concurrent_context();
+  for (size_t i = 0; i < nbatch; ++i) {
+    run_impl(
+        encoder,
+        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M_ * N_,
+        a.data<int8_t>() + a.itemsize() * a_it.loc,
+        b.data<int8_t>() + b.itemsize() * b_it.loc,
+        c.data<int8_t>() + c.itemsize() * c_it.loc,
+        alpha,
+        beta);
+    a_it.step();
+    b_it.step();
+    c_it.step();
+  }
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
new file mode 100644
index 000000000..4e72fdc64
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu
@@ -0,0 +1,206 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
+
+#include <cooperative_groups.h>
+
+namespace mlx::core::cu {
+
+namespace cg = cooperative_groups;
+
+__global__ void set_mm_device_pointers(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ Shape batch_shape,
+    const __grid_constant__ Strides a_batch_strides,
+    const __grid_constant__ Strides b_batch_strides,
+    int64_t batch_stride,
+    int batch_ndim,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset] = elem_to_loc(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      batch_ndim);
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_addmm_device_pointers(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* c_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ Shape batch_shape,
+    const __grid_constant__ Strides a_batch_strides,
+    const __grid_constant__ Strides b_batch_strides,
+    const __grid_constant__ Strides c_batch_strides,
+    int64_t batch_stride,
+    int batch_ndim,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset, c_offset] = elem_to_loc(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      c_batch_strides.data(),
+      batch_ndim);
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
+  pointers[index + 3 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+void set_pointer_mode(cublasLtMatrixLayout_t desc, int batch_count) {
+  auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+      desc,
+      CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
+      &batch_mode,
+      sizeof(batch_mode)));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+      desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(int32_t)));
+}
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides) {
+  auto batch_count = out.size() / (M_ * N_);
+  set_pointer_mode(a_desc_, batch_count);
+  set_pointer_mode(b_desc_, batch_count);
+  set_pointer_mode(out_desc_, batch_count);
+
+  // Launch kernel to set device offsets
+  auto pointers = array(
+      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
+      {static_cast<int>(batch_count * 3)},
+      uint64);
+
+  encoder.add_temporary(pointers);
+  int block_size = 512;
+  encoder.set_output_array(pointers);
+
+  encoder.add_kernel_node(
+      cu::set_mm_device_pointers,
+      cuda::ceil_div(pointers.size(), block_size),
+      block_size,
+      pointers.data<int8_t*>(),
+      a.data<int8_t>(),
+      b.data<int8_t>(),
+      out.data<int8_t>(),
+      static_cast<int>(out.dtype().size()),
+      const_param(batch_shape),
+      const_param(a_batch_strides),
+      const_param(b_batch_strides),
+      static_cast<int64_t>(M_) * N_,
+      static_cast<int>(batch_shape.size()),
+      batch_count);
+
+  // Run matmul
+  encoder.set_input_array(pointers);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+
+  auto a_pointers = pointers.data<uint64_t>();
+  auto b_pointers = a_pointers + batch_count;
+  auto out_pointers = b_pointers + batch_count;
+  run_impl(
+      encoder,
+      reinterpret_cast<void*>(out_pointers),
+      reinterpret_cast<void*>(a_pointers),
+      reinterpret_cast<void*>(b_pointers),
+      nullptr);
+}
+
+void Matmul::run_batched(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const array& c,
+    const mlx::core::Shape& batch_shape,
+    const mlx::core::Strides& a_batch_strides,
+    const mlx::core::Strides& b_batch_strides,
+    const mlx::core::Strides& c_batch_strides,
+    float alpha,
+    float beta) {
+  auto batch_count = out.size() / (M_ * N_);
+  set_pointer_mode(a_desc_, batch_count);
+  set_pointer_mode(b_desc_, batch_count);
+  set_pointer_mode(c_desc_, batch_count);
+  set_pointer_mode(out_desc_, batch_count);
+
+  // Launch kernel to set device offsets
+  auto pointers = array(
+      allocator::malloc(batch_count * sizeof(uint64_t) * 4),
+      {static_cast<int>(batch_count * 4)},
+      uint64);
+
+  encoder.add_temporary(pointers);
+  int block_size = 512;
+  encoder.set_output_array(pointers);
+  encoder.add_kernel_node(
+      cu::set_addmm_device_pointers,
+      cuda::ceil_div(pointers.size(), block_size),
+      block_size,
+      pointers.data<int8_t*>(),
+      a.data<int8_t>(),
+      b.data<int8_t>(),
+      c.data<int8_t>(),
+      out.data<int8_t>(),
+      static_cast<int>(out.dtype().size()),
+      const_param(batch_shape),
+      const_param(a_batch_strides),
+      const_param(b_batch_strides),
+      const_param(c_batch_strides),
+      static_cast<int64_t>(M_) * N_,
+      static_cast<int>(batch_shape.size()),
+      batch_count);
+
+  // Run matmul
+  encoder.set_input_array(pointers);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_input_array(c);
+  encoder.set_output_array(out);
+
+  auto a_pointers = pointers.data<uint64_t>();
+  auto b_pointers = a_pointers + batch_count;
+  auto c_pointers = b_pointers + batch_count;
+  auto out_pointers = c_pointers + batch_count;
+  run_impl(
+      encoder,
+      reinterpret_cast<void*>(out_pointers),
+      reinterpret_cast<void*>(a_pointers),
+      reinterpret_cast<void*>(b_pointers),
+      reinterpret_cast<void*>(c_pointers),
+      alpha,
+      beta);
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.cpp b/mlx/backend/cuda/gemms/cublas_gemm.cpp
new file mode 100644
index 000000000..65cf3beb3
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_gemm.cpp
@@ -0,0 +1,290 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/dtype_utils.h"
+#include "mlx/utils.h"
+
+#include <fmt/format.h>
+
+void check_cublas_error(const char* name, cublasStatus_t err) {
+  if (err != CUBLAS_STATUS_SUCCESS) {
+    // TODO: Use cublasGetStatusString when it is widely available.
+    throw std::runtime_error(
+        fmt::format("{} failed with code: {}.", name, static_cast<int>(err)));
+  }
+}
+
+namespace mlx::core::cu {
+
+struct CublasPreference {
+  CublasPreference(Device& device) {
+    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
+    // for Hopper+:
+    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
+    uint64_t MiB = 1024 * 1024;
+    uint64_t workspace_size =
+        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
+
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
+        pref_,
+        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+        &workspace_size,
+        sizeof(uint64_t)));
+  }
+
+  ~CublasPreference() {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
+  }
+
+  cublasLtMatmulPreference_t pref_{nullptr};
+};
+
+cublasLtMatmulPreference_t cublas_preference(Device& device) {
+  static CublasPreference pref(device);
+  return pref.pref_;
+}
+
+cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
+  switch (dtype) {
+    case float16:
+      return CUBLAS_COMPUTE_32F;
+    case bfloat16:
+      return CUBLAS_COMPUTE_32F;
+    case float32:
+      return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
+                                           : CUBLAS_COMPUTE_32F;
+    case float64:
+    case complex64:
+      return CUBLAS_COMPUTE_64F;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+  }
+}
+
+cudaDataType_t dtype_to_cublas_type(Dtype dtype) {
+  switch (dtype) {
+    case float16:
+      return CUDA_R_16F;
+    case bfloat16:
+      return CUDA_R_16BF;
+    case float32:
+      return CUDA_R_32F;
+    case float64:
+      return CUDA_R_64F;
+    case complex64:
+      return CUDA_C_32F;
+    default:
+      throw std::runtime_error(fmt::format(
+          "Unsupported dtype in Matmul: {}.", dtype_to_string(dtype)));
+  }
+}
+
+cublasLtMatrixLayout_t create_matrix_layout(
+    cudaDataType_t type,
+    uint64_t rows,
+    uint64_t cols,
+    bool transposed,
+    int64_t ld,
+    int32_t batch_count,
+    int64_t batch_stride) {
+  cublasLtMatrixLayout_t desc;
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
+  cublasLtOrder_t order = transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+      desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
+  if (batch_count > 1) {
+    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+        &batch_count,
+        sizeof(int32_t)));
+    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+        &batch_stride,
+        sizeof(int64_t)));
+  }
+  return desc;
+}
+
+Matmul::Matmul(
+    Device& device,
+    Dtype dtype,
+    bool a_transposed,
+    uint64_t a_rows,
+    uint64_t a_cols,
+    int64_t lda,
+    bool b_transposed,
+    uint64_t b_rows,
+    uint64_t b_cols,
+    int64_t ldb,
+    int32_t batch_count,
+    int64_t a_batch_stride,
+    int64_t b_batch_stride)
+    : handle_(device.lt_handle()),
+      pref_(cublas_preference(device)),
+      M_(a_rows),
+      N_(b_cols) {
+  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
+
+  auto scale_type = dtype_to_cublas_type(dtype);
+  if (dtype == bfloat16 || dtype == float16) {
+    scale_type = CUDA_R_32F;
+  }
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
+      &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
+  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_POINTER_MODE,
+      &pointer_mode,
+      sizeof(int32_t)));
+  cublasOperation_t op = CUBLAS_OP_N;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_TRANSA,
+      &op,
+      sizeof(cublasOperation_t)));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_TRANSB,
+      &op,
+      sizeof(cublasOperation_t)));
+
+  auto type = dtype_to_cublas_type(dtype);
+  a_desc_ = create_matrix_layout(
+      type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
+  b_desc_ = create_matrix_layout(
+      type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
+  out_desc_ = create_matrix_layout(
+      type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
+}
+
+Matmul::Matmul(
+    Device& device,
+    Dtype dtype,
+    bool a_transposed,
+    uint64_t a_rows,
+    uint64_t a_cols,
+    int64_t lda,
+    bool b_transposed,
+    uint64_t b_rows,
+    uint64_t b_cols,
+    int64_t ldb,
+    int64_t ldc,
+    int32_t batch_count,
+    int64_t a_batch_stride,
+    int64_t b_batch_stride,
+    int64_t c_batch_stride)
+    : Matmul(
+          device,
+          dtype,
+          a_transposed,
+          a_rows,
+          a_cols,
+          lda,
+          b_transposed,
+          b_rows,
+          b_cols,
+          ldb,
+          batch_count,
+          a_batch_stride,
+          b_batch_stride) {
+  auto type = dtype_to_cublas_type(dtype);
+  c_desc_ = create_matrix_layout(
+      type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
+}
+
+Matmul::~Matmul() {
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
+}
+
+void Matmul::run_impl(
+    cu::CommandEncoder& encoder,
+    void* out,
+    const void* a,
+    const void* b,
+    const void* c,
+    float alpha /* = 1 */,
+    float beta /* = 0 */) {
+  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
+    int ret = 0;
+    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
+        handle_,
+        matmul_desc_,
+        a_desc_,
+        b_desc_,
+        out_desc_, // TODO should that be c_desc if it's set?
+        out_desc_,
+        pref_,
+        1,
+        &heuristic_,
+        &ret));
+    if (ret == 0) {
+      throw std::runtime_error("Can not find algorithm for matmul.");
+    }
+  }
+
+  void* workspace_ptr = nullptr;
+  if (heuristic_.workspaceSize > 0) {
+    array workspace(
+        allocator::malloc(heuristic_.workspaceSize),
+        {static_cast<int>(heuristic_.workspaceSize)},
+        int8);
+    encoder.add_temporary(workspace);
+    workspace_ptr = workspace.data<void>();
+  }
+
+  auto capture = encoder.capture_context();
+  CHECK_CUBLAS_ERROR(cublasLtMatmul(
+      handle_,
+      matmul_desc_,
+      &alpha,
+      a,
+      a_desc_,
+      b,
+      b_desc_,
+      &beta,
+      c ? c : out,
+      c ? c_desc_ : out_desc_,
+      out,
+      out_desc_,
+      &heuristic_.algo,
+      workspace_ptr,
+      heuristic_.workspaceSize,
+      encoder.stream()));
+}
+
+void Matmul::run(
+    cu::CommandEncoder& encoder,
+    array& out,
+    const array& a,
+    const array& b,
+    const std::optional<array>& c /* = std::nullopt */,
+    float alpha /* = 1 */,
+    float beta /* = 0 */) {
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  if (c) {
+    encoder.set_input_array(*c);
+  }
+  encoder.set_output_array(out);
+
+  run_impl(
+      encoder,
+      out.data<void>(),
+      a.data<void>(),
+      b.data<void>(),
+      c ? c->data<void>() : nullptr,
+      alpha,
+      beta);
+}
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemms/cublas_gemm.h b/mlx/backend/cuda/gemms/cublas_gemm.h
new file mode 100644
index 000000000..2babb808e
--- /dev/null
+++ b/mlx/backend/cuda/gemms/cublas_gemm.h
@@ -0,0 +1,104 @@
+// Copyright © 2025 Apple Inc.
+#pragma once
+
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device.h"
+
+#include <cublasLt.h>
+#include <optional>
+
+void check_cublas_error(const char* name, cublasStatus_t err);
+
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
+
+namespace mlx::core::cu {
+class Matmul {
+ public:
+  Matmul(
+      Device& device,
+      Dtype dtype,
+      bool a_transposed,
+      uint64_t a_rows,
+      uint64_t a_cols,
+      int64_t lda,
+      bool b_transposed,
+      uint64_t b_rows,
+      uint64_t b_cols,
+      int64_t ldb,
+      int32_t batch_count,
+      int64_t a_batch_stride,
+      int64_t b_batch_stride);
+
+  Matmul(
+      Device& device,
+      Dtype dtype,
+      bool a_transposed,
+      uint64_t a_rows,
+      uint64_t a_cols,
+      int64_t lda,
+      bool b_transposed,
+      uint64_t b_rows,
+      uint64_t b_cols,
+      int64_t ldb,
+      int64_t ldc,
+      int32_t batch_count,
+      int64_t a_batch_stride,
+      int64_t b_batch_stride,
+      int64_t c_batch_stride);
+
+  ~Matmul();
+
+  void run(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const std::optional<array>& c = std::nullopt,
+      float alpha = 1,
+      float beta = 0);
+
+  void run_batched(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const mlx::core::Shape& batch_shape,
+      const mlx::core::Strides& a_batch_strides,
+      const mlx::core::Strides& b_batch_strides);
+
+  void run_batched(
+      cu::CommandEncoder& encoder,
+      array& out,
+      const array& a,
+      const array& b,
+      const array& c,
+      const mlx::core::Shape& batch_shape,
+      const mlx::core::Strides& a_batch_strides,
+      const mlx::core::Strides& b_batch_strides,
+      const mlx::core::Strides& c_batch_strides,
+      float alpha,
+      float beta);
+
+ private:
+  void run_impl(
+      cu::CommandEncoder& encoder,
+      void* out,
+      const void* a,
+      const void* b,
+      const void* c,
+      float alpha = 1,
+      float beta = 0);
+
+  uint64_t M_;
+  uint64_t N_;
+  cublasLtMatmulPreference_t pref_{nullptr};
+  cublasLtHandle_t handle_{nullptr};
+  cublasLtMatmulDesc_t matmul_desc_{nullptr};
+  cublasLtMatrixLayout_t a_desc_{nullptr};
+  cublasLtMatrixLayout_t b_desc_{nullptr};
+  cublasLtMatrixLayout_t c_desc_{nullptr};
+  cublasLtMatrixLayout_t out_desc_{nullptr};
+  cublasLtMatmulHeuristicResult_t heuristic_;
+};
+
+} // namespace mlx::core::cu
diff --git a/mlx/backend/cuda/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu
similarity index 99%
rename from mlx/backend/cuda/gemv.cu
rename to mlx/backend/cuda/gemms/gemv.cu
index fe0f7a327..b62d6e6b8 100644
--- a/mlx/backend/cuda/gemv.cu
+++ b/mlx/backend/cuda/gemms/gemv.cu
@@ -1,6 +1,6 @@
 // Copyright © 2025 Apple Inc.
 
-#include "mlx/backend/cuda/gemv.h"
+#include "mlx/backend/cuda/gemms/gemv.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
 
diff --git a/mlx/backend/cuda/gemv.h b/mlx/backend/cuda/gemms/gemv.h
similarity index 100%
rename from mlx/backend/cuda/gemv.h
rename to mlx/backend/cuda/gemms/gemv.h
diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
new file mode 100644
index 000000000..283aaaf2e
--- /dev/null
+++ b/mlx/backend/cuda/matmul.cpp
@@ -0,0 +1,224 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/matmul.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/backend/cuda/gemms/cublas_gemm.h"
+#include "mlx/backend/cuda/gemms/gemv.h"
+#include "mlx/backend/gpu/copy.h"
+#include "mlx/primitives.h"
+
+#include <nvtx3/nvtx3.hpp>
+#include <numeric>
+
+namespace mlx::core {
+namespace {
+
+std::tuple<bool, int64_t, array>
+check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
+  auto stx = arr.strides()[arr.ndim() - 2];
+  auto sty = arr.strides()[arr.ndim() - 1];
+  if (sty == 1 && stx == arr.shape(-1)) {
+    return std::make_tuple(false, stx, arr);
+  } else if (stx == 1 && sty == arr.shape(-2)) {
+    return std::make_tuple(true, sty, arr);
+  } else {
+    array arr_copy = contiguous_copy_gpu(arr, s);
+    enc.add_temporary(arr_copy);
+    return std::make_tuple(false, arr.shape(-1), arr_copy);
+  }
+}
+
+} // namespace
+
+void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("Matmul::eval_gpu");
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  assert(inputs.size() == 2);
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+  // Return 0s if either input is empty.
+  if (a_pre.size() == 0 || b_pre.size() == 0) {
+    array zero(0, a_pre.dtype());
+    encoder.add_temporary(zero);
+    fill_gpu(zero, out, s);
+    return;
+  }
+
+  out.set_data(allocator::malloc(out.nbytes()));
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Init checks and prep
+
+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
+  // Keep a vector with copies to be cleared in the completed buffer to release
+  // the arrays
+  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
+  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Check and collapse batch dimensions
+
+  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
+
+  auto batch_count = out.size() / (M * N);
+
+  // Collapse batches into M if needed
+  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
+      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
+      b_batch_strides.back() == 0) {
+    M *= batch_shape.back();
+    batch_count = 1;
+
+    a_batch_strides = {0};
+    b_batch_strides = {0};
+    batch_shape = {1};
+  }
+
+  if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
+    cu::gemv(
+        a,
+        b,
+        out,
+        M,
+        N,
+        K,
+        batch_count,
+        batch_shape,
+        a_batch_strides,
+        b_batch_strides,
+        encoder);
+    return;
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Invoke cublasLt
+  cu::Matmul matmul(
+      cu::device(s.device),
+      a.dtype(),
+      a_transposed,
+      M,
+      K,
+      lda,
+      b_transposed,
+      K,
+      N,
+      ldb,
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back());
+
+  if ((batch_count / batch_shape.back()) == 1) {
+    matmul.run(encoder, out, a, b);
+    return;
+  }
+
+  matmul.run_batched(
+      encoder, out, a, b, batch_shape, a_batch_strides, b_batch_strides);
+}
+
+void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
+  nvtx3::scoped_range r("AddMM::eval_gpu");
+  auto& s = stream();
+  auto& encoder = cu::get_command_encoder(s);
+
+  assert(inputs.size() == 3);
+  auto& a_pre = inputs[0];
+  auto& b_pre = inputs[1];
+  auto c = inputs[2];
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Init checks and prep
+
+  int M = a_pre.shape(-2);
+  int N = b_pre.shape(-1);
+  int K = a_pre.shape(-1);
+
+  // Keep a vector with copies to be cleared in the completed buffer to release
+  // the arrays
+  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
+  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
+
+  int64_t ldc;
+  {
+    auto stx = c.strides()[c.ndim() - 2];
+    auto sty = c.strides()[c.ndim() - 1];
+    if (sty == 1 && stx == c.shape(-1)) {
+      ldc = stx;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else if (sty == 1 && stx == 0) {
+      ldc = 0;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else {
+      // Copy C into out and set C to out
+      ldc = c.shape(-1);
+      copy_gpu(c, out, CopyType::General, s);
+      c = out;
+    }
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Check and collapse batch dimensions
+
+  auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
+      collapse_batches(a, b, c);
+
+  auto batch_count = out.size() / (M * N);
+
+  // Collapse batches into M if needed
+  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
+      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
+      c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
+      b_batch_strides.back() == 0) {
+    M *= batch_shape.back();
+    batch_count = 1;
+
+    a_batch_strides = {0};
+    b_batch_strides = {0};
+    c_batch_strides = {0};
+    batch_shape = {1};
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Invoke cublasLt
+
+  cu::Matmul matmul(
+      cu::device(s.device),
+      a.dtype(),
+      a_transposed,
+      M,
+      K,
+      lda,
+      b_transposed,
+      K,
+      N,
+      ldb,
+      ldc,
+      batch_shape.back(),
+      a_batch_strides.back(),
+      b_batch_strides.back(),
+      c_batch_strides.back());
+
+  if ((batch_count / batch_shape.back()) == 1) {
+    matmul.run(encoder, out, a, b, c, alpha_, beta_);
+    return;
+  }
+  matmul.run_batched(
+      encoder,
+      out,
+      a,
+      b,
+      c,
+      batch_shape,
+      a_batch_strides,
+      b_batch_strides,
+      c_batch_strides,
+      alpha_,
+      beta_);
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/cuda/matmul.cu b/mlx/backend/cuda/matmul.cu
deleted file mode 100644
index 998b61609..000000000
--- a/mlx/backend/cuda/matmul.cu
+++ /dev/null
@@ -1,574 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/common/matmul.h"
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/gemv.h"
-#include "mlx/backend/cuda/kernel_utils.cuh"
-#include "mlx/backend/gpu/copy.h"
-#include "mlx/dtype_utils.h"
-#include "mlx/primitives.h"
-#include "mlx/utils.h"
-
-#include <cublasLt.h>
-#include <fmt/format.h>
-#include <nvtx3/nvtx3.hpp>
-
-#include <numeric>
-
-namespace mlx::core {
-
-namespace cu {
-
-struct CublasPreference {
-  CublasPreference(Device& device) {
-    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
-    // for Hopper+:
-    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
-    uint64_t MiB = 1024 * 1024;
-    uint64_t workspace_size =
-        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
-
-    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
-    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
-        pref_,
-        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
-        &workspace_size,
-        sizeof(uint64_t)));
-  }
-
-  ~CublasPreference() {
-    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
-  }
-
-  cublasLtMatmulPreference_t pref_{nullptr};
-};
-
-cublasLtMatmulPreference_t cublas_preference(Device& device) {
-  static CublasPreference pref(device);
-  return pref.pref_;
-}
-
-class MatMul {
- public:
-  MatMul(
-      Device& device,
-      Dtype dtype,
-      bool a_transposed,
-      uint64_t a_rows,
-      uint64_t a_cols,
-      int64_t lda,
-      bool b_transposed,
-      uint64_t b_rows,
-      uint64_t b_cols,
-      int64_t ldb,
-      int32_t batch_count,
-      int64_t a_batch_stride,
-      int64_t b_batch_stride)
-      : handle_(device.lt_handle()), pref_(cublas_preference(device)) {
-    heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
-
-    auto scale_type = dtype_to_cuda_type(dtype);
-    if (dtype == bfloat16 || dtype == float16) {
-      scale_type = CUDA_R_32F;
-    }
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(
-        &matmul_desc_, dtype_to_compute_type(dtype), scale_type));
-    int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-        matmul_desc_,
-        CUBLASLT_MATMUL_DESC_POINTER_MODE,
-        &pointer_mode,
-        sizeof(int32_t)));
-    cublasOperation_t op = CUBLAS_OP_N;
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-        matmul_desc_,
-        CUBLASLT_MATMUL_DESC_TRANSA,
-        &op,
-        sizeof(cublasOperation_t)));
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
-        matmul_desc_,
-        CUBLASLT_MATMUL_DESC_TRANSB,
-        &op,
-        sizeof(cublasOperation_t)));
-
-    auto type = dtype_to_cuda_type(dtype);
-    a_desc_ = create_matrix_layout(
-        type, a_rows, a_cols, a_transposed, lda, batch_count, a_batch_stride);
-    b_desc_ = create_matrix_layout(
-        type, b_rows, b_cols, b_transposed, ldb, batch_count, b_batch_stride);
-    out_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, false, b_cols, batch_count, a_rows * b_cols);
-  }
-
-  MatMul(
-      Device& device,
-      Dtype dtype,
-      bool a_transposed,
-      uint64_t a_rows,
-      uint64_t a_cols,
-      int64_t lda,
-      bool b_transposed,
-      uint64_t b_rows,
-      uint64_t b_cols,
-      int64_t ldb,
-      int64_t ldc,
-      int32_t batch_count,
-      int64_t a_batch_stride,
-      int64_t b_batch_stride,
-      int64_t c_batch_stride)
-      : MatMul(
-            device,
-            dtype,
-            a_transposed,
-            a_rows,
-            a_cols,
-            lda,
-            b_transposed,
-            b_rows,
-            b_cols,
-            ldb,
-            batch_count,
-            a_batch_stride,
-            b_batch_stride) {
-    auto type = dtype_to_cuda_type(dtype);
-    c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
-  }
-
-  ~MatMul() {
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
-    CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
-  }
-
-  void run(
-      cu::CommandEncoder& encoder,
-      void* out,
-      void* a,
-      void* b,
-      void* c = nullptr,
-      float alpha = 1,
-      float beta = 0) {
-    if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
-      int ret = 0;
-      CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
-          handle_,
-          matmul_desc_,
-          a_desc_,
-          b_desc_,
-          out_desc_,
-          out_desc_,
-          pref_,
-          1,
-          &heuristic_,
-          &ret));
-      if (ret == 0) {
-        throw std::runtime_error("Can not find algorithm for matmul.");
-      }
-    }
-
-    void* workspace_ptr = nullptr;
-    if (heuristic_.workspaceSize > 0) {
-      array workspace(
-          allocator::malloc(heuristic_.workspaceSize),
-          {static_cast<int>(heuristic_.workspaceSize)},
-          int8);
-      encoder.add_temporary(workspace);
-      workspace_ptr = workspace.data<void>();
-    }
-
-    auto capture = encoder.capture_context();
-    CHECK_CUBLAS_ERROR(cublasLtMatmul(
-        handle_,
-        matmul_desc_,
-        &alpha,
-        a,
-        a_desc_,
-        b,
-        b_desc_,
-        &beta,
-        c ? c : out,
-        c ? c_desc_ : out_desc_,
-        out,
-        out_desc_,
-        &heuristic_.algo,
-        workspace_ptr,
-        heuristic_.workspaceSize,
-        encoder.stream()));
-  }
-
-  void use_batch_pointer_mode(int batch_count) {
-    auto set_pointer_mode = [&batch_count](auto desc) {
-      auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
-          &batch_mode,
-          sizeof(batch_mode)));
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
-          &batch_count,
-          sizeof(int32_t)));
-    };
-    set_pointer_mode(a_desc_);
-    set_pointer_mode(b_desc_);
-    if (c_desc_) {
-      set_pointer_mode(c_desc_);
-    }
-    set_pointer_mode(out_desc_);
-  }
-
- private:
-  cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
-    switch (dtype) {
-      case float16:
-        return CUBLAS_COMPUTE_32F;
-      case bfloat16:
-        return CUBLAS_COMPUTE_32F;
-      case float32:
-        return mlx::core::env::enable_tf32() ? CUBLAS_COMPUTE_32F_FAST_TF32
-                                             : CUBLAS_COMPUTE_32F;
-      case float64:
-      case complex64:
-        return CUBLAS_COMPUTE_64F;
-      default:
-        throw std::runtime_error(fmt::format(
-            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
-    }
-  }
-
-  cudaDataType_t dtype_to_cuda_type(Dtype dtype) {
-    switch (dtype) {
-      case float16:
-        return CUDA_R_16F;
-      case bfloat16:
-        return CUDA_R_16BF;
-      case float32:
-        return CUDA_R_32F;
-      case float64:
-        return CUDA_R_64F;
-      case complex64:
-        return CUDA_C_32F;
-      default:
-        throw std::runtime_error(fmt::format(
-            "Unsupported dtype in MatMul: {}.", dtype_to_string(dtype)));
-    }
-  }
-
-  cublasLtMatrixLayout_t create_matrix_layout(
-      cudaDataType_t type,
-      uint64_t rows,
-      uint64_t cols,
-      bool transposed,
-      int64_t ld,
-      int32_t batch_count,
-      int64_t batch_stride) {
-    cublasLtMatrixLayout_t desc;
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
-    cublasLtOrder_t order =
-        transposed ? CUBLASLT_ORDER_COL : CUBLASLT_ORDER_ROW;
-    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-        desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order, sizeof(cublasLtOrder_t)));
-    if (batch_count > 1) {
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
-          &batch_count,
-          sizeof(int32_t)));
-      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
-          desc,
-          CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
-          &batch_stride,
-          sizeof(int64_t)));
-    }
-    return desc;
-  }
-
-  cublasLtMatmulPreference_t pref_{nullptr};
-  cublasLtHandle_t handle_{nullptr};
-  cublasLtMatmulDesc_t matmul_desc_{nullptr};
-  cublasLtMatrixLayout_t a_desc_{nullptr};
-  cublasLtMatrixLayout_t b_desc_{nullptr};
-  cublasLtMatrixLayout_t c_desc_{nullptr};
-  cublasLtMatrixLayout_t out_desc_{nullptr};
-  cublasLtMatmulHeuristicResult_t heuristic_;
-};
-
-} // namespace cu
-
-namespace {
-
-std::tuple<bool, int64_t, array>
-check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
-  auto stx = arr.strides()[arr.ndim() - 2];
-  auto sty = arr.strides()[arr.ndim() - 1];
-  if (sty == 1 && stx == arr.shape(-1)) {
-    return std::make_tuple(false, stx, arr);
-  } else if (stx == 1 && sty == arr.shape(-2)) {
-    return std::make_tuple(true, sty, arr);
-  } else {
-    array arr_copy = contiguous_copy_gpu(arr, s);
-    enc.add_temporary(arr_copy);
-    return std::make_tuple(false, arr.shape(-1), arr_copy);
-  }
-}
-
-} // namespace
-
-void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("Matmul::eval_gpu");
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-
-  assert(inputs.size() == 2);
-  auto& a_pre = inputs[0];
-  auto& b_pre = inputs[1];
-  // Return 0s if either input is empty.
-  if (a_pre.size() == 0 || b_pre.size() == 0) {
-    array zero(0, a_pre.dtype());
-    encoder.add_temporary(zero);
-    fill_gpu(zero, out, s);
-    return;
-  }
-
-  out.set_data(allocator::malloc(out.nbytes()));
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Init checks and prep
-
-  int M = a_pre.shape(-2);
-  int N = b_pre.shape(-1);
-  int K = a_pre.shape(-1);
-
-  // Keep a vector with copies to be cleared in the completed buffer to release
-  // the arrays
-  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
-  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Check and collapse batch dimensions
-
-  auto [batch_shape, a_batch_strides, b_batch_strides] = collapse_batches(a, b);
-
-  auto batch_count = out.size() / (M * N);
-
-  // Collapse batches into M if needed
-  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
-      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
-      b_batch_strides.back() == 0) {
-    M *= batch_shape.back();
-    batch_count = 1;
-
-    a_batch_strides = {0};
-    b_batch_strides = {0};
-    batch_shape = {1};
-  }
-
-  if (cu::can_use_gemv(M, N, K, a_transposed, b_transposed)) {
-    cu::gemv(
-        a,
-        b,
-        out,
-        M,
-        N,
-        K,
-        batch_count,
-        batch_shape,
-        a_batch_strides,
-        b_batch_strides,
-        encoder);
-    return;
-  }
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
-
-  cu::MatMul matmul(
-      cu::device(s.device),
-      a.dtype(),
-      a_transposed,
-      M,
-      K,
-      lda,
-      b_transposed,
-      K,
-      N,
-      ldb,
-      batch_shape.back(),
-      a_batch_strides.back(),
-      b_batch_strides.back());
-
-  if ((batch_count / batch_shape.back()) == 1) {
-    encoder.set_input_array(a);
-    encoder.set_input_array(b);
-    encoder.set_output_array(out);
-    matmul.run(encoder, out.data<void>(), a.data<void>(), b.data<void>());
-    return;
-  }
-
-  // If we get here use pointer mode
-  matmul.use_batch_pointer_mode(batch_count);
-
-  // Launch kernel to set device offsets
-  auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 3), {static_cast<int>(batch_count * 3)}, uint64);
-  encoder.add_temporary(pointers);
-  int block_size = 512;
-  encoder.set_output_array(pointers);
-
-  encoder.add_kernel_node(
-      cu::set_mm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      static_cast<int64_t>(M) * N,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
-
-  // Run matmul
-  encoder.set_input_array(pointers);
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-
-  auto a_pointers = pointers.data<uint64_t>();
-  auto b_pointers = a_pointers + batch_count;
-  auto out_pointers = b_pointers + batch_count;
-  matmul.run(
-      encoder,
-      reinterpret_cast<void*>(out_pointers),
-      reinterpret_cast<void*>(a_pointers),
-      reinterpret_cast<void*>(b_pointers));
-}
-
-void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
-  nvtx3::scoped_range r("AddMM::eval_gpu");
-  auto& s = stream();
-  auto& encoder = cu::get_command_encoder(s);
-
-  assert(inputs.size() == 3);
-  auto& a_pre = inputs[0];
-  auto& b_pre = inputs[1];
-  auto c = inputs[2];
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Init checks and prep
-
-  int M = a_pre.shape(-2);
-  int N = b_pre.shape(-1);
-  int K = a_pre.shape(-1);
-
-  // Keep a vector with copies to be cleared in the completed buffer to release
-  // the arrays
-  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
-  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
-
-  int64_t ldc;
-  {
-    auto stx = c.strides()[c.ndim() - 2];
-    auto sty = c.strides()[c.ndim() - 1];
-    if (sty == 1 && stx == c.shape(-1)) {
-      ldc = stx;
-      out.set_data(allocator::malloc(out.nbytes()));
-    } else if (sty == 1 && stx == 0) {
-      ldc = 0;
-      out.set_data(allocator::malloc(out.nbytes()));
-    } else {
-      // Copy C into out and set C to out
-      ldc = c.shape(-1);
-      copy_gpu(c, out, CopyType::General, s);
-      c = out;
-    }
-  }
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Check and collapse batch dimensions
-
-  auto [batch_shape, a_batch_strides, b_batch_strides, c_batch_strides] =
-      collapse_batches(a, b, c);
-
-  auto batch_count = out.size() / (M * N);
-
-  // Collapse batches into M if needed
-  if (batch_count > 1 && !a_transposed && batch_shape.size() == 1 &&
-      a.strides()[a.ndim() - 2] == K && a_batch_strides.back() == M * K &&
-      c_batch_strides.back() == M * c.strides()[c.ndim() - 2] &&
-      b_batch_strides.back() == 0) {
-    M *= batch_shape.back();
-    batch_count = 1;
-
-    a_batch_strides = {0};
-    b_batch_strides = {0};
-    c_batch_strides = {0};
-    batch_shape = {1};
-  }
-
-  /////////////////////////////////////////////////////////////////////////////
-  // Invoke cublasLt
-
-  cu::MatMul matmul(
-      cu::device(s.device),
-      a.dtype(),
-      a_transposed,
-      M,
-      K,
-      lda,
-      b_transposed,
-      K,
-      N,
-      ldb,
-      ldc,
-      batch_shape.back(),
-      a_batch_strides.back(),
-      b_batch_strides.back(),
-      c_batch_strides.back());
-
-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_input_array(c);
-  encoder.set_output_array(out);
-
-  // TODO use pointer mode here as well
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
-    matmul.run(
-        encoder,
-        out.data<void>(),
-        a.data<void>(),
-        b.data<void>(),
-        c.data<void>(),
-        alpha_,
-        beta_);
-    return;
-  }
-
-  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator c_it(batch_shape, c_batch_strides, batch_shape.size() - 1);
-  auto concurrent = encoder.concurrent_context();
-  for (size_t i = 0; i < nbatch; ++i) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc,
-        c.data<int8_t>() + c.itemsize() * c_it.loc,
-        alpha_,
-        beta_);
-    a_it.step();
-    b_it.step();
-    c_it.step();
-  }
-}
-
-} // namespace mlx::core
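
Reviewer note (not part of the patch): the sketch below shows how the cu::Matmul wrapper declared in cublas_gemm.h above is meant to be driven for a single, unbatched addmm-style call (out = alpha * a @ b + beta * c), mirroring what AddMM::eval_gpu does in matmul.cpp. The shapes, the float32 dtype, and the variables M, N, K, s, encoder, a, b, c, out, alpha, and beta are illustrative assumptions, not code from the change.

    // Sketch only: row-major [M, K] x [K, N] + [M, N], batch_count == 1,
    // inside mlx::core with `encoder` the cu::CommandEncoder for stream `s`.
    cu::Matmul matmul(
        cu::device(s.device),
        float32,
        /* a_transposed */ false,
        /* a_rows */ M,
        /* a_cols */ K,
        /* lda */ K,
        /* b_transposed */ false,
        /* b_rows */ K,
        /* b_cols */ N,
        /* ldb */ N,
        /* ldc */ N,
        /* batch_count */ 1,
        /* a_batch_stride */ 0,
        /* b_batch_stride */ 0,
        /* c_batch_stride */ 0);
    matmul.run(encoder, out, a, b, c, alpha, beta);

Batched calls go through run_batched instead, which the build selects per toolkit version in CMakeLists.txt: cublas_batched_gemm_12_9.cu uses cublasLt pointer-array batch mode (device pointers filled by the set_mm_device_pointers / set_addmm_device_pointers kernels), while cublas_batched_gemm_12_0.cpp falls back to a loop over run_impl with ContiguousIterator offsets.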