cuda batched mm

This commit is contained in:
Awni Hannun
2025-07-21 14:11:18 -07:00
parent 7d9d6ef456
commit 421014c50a
2 changed files with 70 additions and 18 deletions

View File

@@ -25,7 +25,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu

View File

@@ -3,12 +3,14 @@
#include "mlx/backend/common/matmul.h" #include "mlx/backend/common/matmul.h"
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/gemv.h" #include "mlx/backend/cuda/gemv.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h" #include "mlx/dtype_utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include <fmt/format.h> #include <fmt/format.h>
#include <cooperative_groups.h>
#include <nvtx3/nvtx3.hpp> #include <nvtx3/nvtx3.hpp>
#include <numeric> #include <numeric>
@@ -197,6 +199,28 @@ class MatMul {
encoder.stream())); encoder.stream()));
} }
void use_batch_pointer_mode(int batch_count) {
auto set_pointer_mode = [&batch_count](auto desc) {
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
&batch_mode,
sizeof(batch_mode)));
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch_count,
sizeof(int32_t)));
};
set_pointer_mode(a_desc_);
set_pointer_mode(b_desc_);
if (c_desc_) {
set_pointer_mode(c_desc_);
}
set_pointer_mode(out_desc_);
}
private: private:
cublasComputeType_t dtype_to_compute_type(Dtype dtype) { cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
switch (dtype) { switch (dtype) {
@@ -377,27 +401,54 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
a_batch_strides.back(), a_batch_strides.back(),
b_batch_strides.back()); b_batch_strides.back());
if ((batch_count / batch_shape.back()) == 1) {
encoder.set_input_array(a); encoder.set_input_array(a);
encoder.set_input_array(b); encoder.set_input_array(b);
encoder.set_output_array(out); encoder.set_output_array(out);
auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) {
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>()); matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
return; return;
} }
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1); // If we get here use pointer mode
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1); matmul.use_batch_pointer_mode(batch_count);
auto concurrent = encoder.concurrent_context();
for (size_t i = 0; i < nbatch; ++i) { // Launch kernel to set device offsets
auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3), {static_cast<int>(batch_count * 3)}, uint64);
encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_mm_device_pointers,
cuda::ceil_div(pointers.size(), block_size),
block_size,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
static_cast<int>(out.dtype().size()),
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
static_cast<int64_t>(M) * N,
static_cast<int>(batch_shape.size()),
batch_count);
// Run matmul
encoder.set_input_array(pointers);
encoder.set_input_array(a);
encoder.set_input_array(b);
encoder.set_output_array(out);
auto a_pointers = pointers.data<int8_t*>();
auto b_pointers = a_pointers + batch_count;
auto out_pointers = b_pointers + batch_count;
matmul.run( matmul.run(
encoder, encoder,
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N, reinterpret_cast<int8_t*>(out_pointers),
a.data<int8_t>() + a.itemsize() * a_it.loc, reinterpret_cast<int8_t*>(a_pointers),
b.data<int8_t>() + b.itemsize() * b_it.loc); reinterpret_cast<int8_t*>(b_pointers));
a_it.step();
b_it.step();
}
} }
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) { void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -487,6 +538,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(c); encoder.set_input_array(c);
encoder.set_output_array(out); encoder.set_output_array(out);
// TODO use pointer mode here as well
auto nbatch = batch_count / batch_shape.back(); auto nbatch = batch_count / batch_shape.back();
if (nbatch == 1) { if (nbatch == 1) {
matmul.run( matmul.run(