diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 49658dcd8..0e8f64e20 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -25,7 +25,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu - ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cu similarity index 86% rename from mlx/backend/cuda/matmul.cpp rename to mlx/backend/cuda/matmul.cu index efddf2506..998b61609 100644 --- a/mlx/backend/cuda/matmul.cpp +++ b/mlx/backend/cuda/matmul.cu @@ -3,12 +3,14 @@ #include "mlx/backend/common/matmul.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/gemv.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" #include "mlx/utils.h" #include +#include #include #include @@ -197,6 +199,28 @@ class MatMul { encoder.stream())); } + void use_batch_pointer_mode(int batch_count) { + auto set_pointer_mode = [&batch_count](auto desc) { + auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY; + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_MODE, + &batch_mode, + sizeof(batch_mode))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch_count, + sizeof(int32_t))); + }; + set_pointer_mode(a_desc_); + set_pointer_mode(b_desc_); + if (c_desc_) { + set_pointer_mode(c_desc_); + } + set_pointer_mode(out_desc_); + } + private: cublasComputeType_t dtype_to_compute_type(Dtype dtype) { switch (dtype) { @@ -377,27 +401,54 @@ void Matmul::eval_gpu(const std::vector& inputs, array& out) { a_batch_strides.back(), b_batch_strides.back()); - encoder.set_input_array(a); - encoder.set_input_array(b); - encoder.set_output_array(out); - auto nbatch = batch_count / batch_shape.back(); - if (nbatch == 1) { + if ((batch_count / batch_shape.back()) == 1) { + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); matmul.run(encoder, out.data(), a.data(), b.data()); return; } - ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); - ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); - auto concurrent = encoder.concurrent_context(); - for (size_t i = 0; i < nbatch; ++i) { - matmul.run( - encoder, - out.data() + out.itemsize() * i * batch_shape.back() * M * N, - a.data() + a.itemsize() * a_it.loc, - b.data() + b.itemsize() * b_it.loc); - a_it.step(); - b_it.step(); - } + // If we get here use pointer mode + matmul.use_batch_pointer_mode(batch_count); + + // Launch kernel to set device offsets + auto pointers = array( + allocator::malloc(batch_count * sizeof(uint64_t) * 3), {static_cast(batch_count * 3)}, uint64); + encoder.add_temporary(pointers); + int block_size = 512; + encoder.set_output_array(pointers); + + encoder.add_kernel_node( + cu::set_mm_device_pointers, + cuda::ceil_div(pointers.size(), block_size), + block_size, + pointers.data(), + a.data(), + b.data(), + out.data(), + static_cast(out.dtype().size()), + const_param(batch_shape), + const_param(a_batch_strides), + const_param(b_batch_strides), + static_cast(M) * N, + static_cast(batch_shape.size()), + batch_count); + + // Run matmul + encoder.set_input_array(pointers); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + + auto a_pointers = pointers.data(); + auto b_pointers = a_pointers + batch_count; + auto out_pointers = b_pointers + batch_count; + matmul.run( + encoder, + reinterpret_cast(out_pointers), + reinterpret_cast(a_pointers), + reinterpret_cast(b_pointers)); } void AddMM::eval_gpu(const std::vector& inputs, array& out) { @@ -487,6 +538,7 @@ void AddMM::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(c); encoder.set_output_array(out); + // TODO use pointer mode here as well auto nbatch = batch_count / batch_shape.back(); if (nbatch == 1) { matmul.run(