mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
cuda batched mm
This commit is contained in:
@@ -25,7 +25,7 @@ target_sources(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
#include "mlx/backend/common/matmul.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/gemv.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
#include <fmt/format.h>
|
||||
#include <cooperative_groups.h>
|
||||
#include <nvtx3/nvtx3.hpp>
|
||||
|
||||
#include <numeric>
|
||||
@@ -197,6 +199,28 @@ class MatMul {
|
||||
encoder.stream()));
|
||||
}
|
||||
|
||||
void use_batch_pointer_mode(int batch_count) {
|
||||
auto set_pointer_mode = [&batch_count](auto desc) {
|
||||
auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
|
||||
&batch_mode,
|
||||
sizeof(batch_mode)));
|
||||
CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
|
||||
desc,
|
||||
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&batch_count,
|
||||
sizeof(int32_t)));
|
||||
};
|
||||
set_pointer_mode(a_desc_);
|
||||
set_pointer_mode(b_desc_);
|
||||
if (c_desc_) {
|
||||
set_pointer_mode(c_desc_);
|
||||
}
|
||||
set_pointer_mode(out_desc_);
|
||||
}
|
||||
|
||||
private:
|
||||
cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
|
||||
switch (dtype) {
|
||||
@@ -377,27 +401,54 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
a_batch_strides.back(),
|
||||
b_batch_strides.back());
|
||||
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
if ((batch_count / batch_shape.back()) == 1) {
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
|
||||
return;
|
||||
}
|
||||
|
||||
ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
|
||||
ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
|
||||
auto concurrent = encoder.concurrent_context();
|
||||
for (size_t i = 0; i < nbatch; ++i) {
|
||||
matmul.run(
|
||||
encoder,
|
||||
out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
|
||||
a.data<int8_t>() + a.itemsize() * a_it.loc,
|
||||
b.data<int8_t>() + b.itemsize() * b_it.loc);
|
||||
a_it.step();
|
||||
b_it.step();
|
||||
}
|
||||
// If we get here use pointer mode
|
||||
matmul.use_batch_pointer_mode(batch_count);
|
||||
|
||||
// Launch kernel to set device offsets
|
||||
auto pointers = array(
|
||||
allocator::malloc(batch_count * sizeof(uint64_t) * 3), {static_cast<int>(batch_count * 3)}, uint64);
|
||||
encoder.add_temporary(pointers);
|
||||
int block_size = 512;
|
||||
encoder.set_output_array(pointers);
|
||||
|
||||
encoder.add_kernel_node(
|
||||
cu::set_mm_device_pointers,
|
||||
cuda::ceil_div(pointers.size(), block_size),
|
||||
block_size,
|
||||
pointers.data<int8_t*>(),
|
||||
a.data<int8_t>(),
|
||||
b.data<int8_t>(),
|
||||
out.data<int8_t>(),
|
||||
static_cast<int>(out.dtype().size()),
|
||||
const_param(batch_shape),
|
||||
const_param(a_batch_strides),
|
||||
const_param(b_batch_strides),
|
||||
static_cast<int64_t>(M) * N,
|
||||
static_cast<int>(batch_shape.size()),
|
||||
batch_count);
|
||||
|
||||
// Run matmul
|
||||
encoder.set_input_array(pointers);
|
||||
encoder.set_input_array(a);
|
||||
encoder.set_input_array(b);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto a_pointers = pointers.data<int8_t*>();
|
||||
auto b_pointers = a_pointers + batch_count;
|
||||
auto out_pointers = b_pointers + batch_count;
|
||||
matmul.run(
|
||||
encoder,
|
||||
reinterpret_cast<int8_t*>(out_pointers),
|
||||
reinterpret_cast<int8_t*>(a_pointers),
|
||||
reinterpret_cast<int8_t*>(b_pointers));
|
||||
}
|
||||
|
||||
void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -487,6 +538,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
encoder.set_input_array(c);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
// TODO use pointer mode here as well
|
||||
auto nbatch = batch_count / batch_shape.back();
|
||||
if (nbatch == 1) {
|
||||
matmul.run(
|
||||
Reference in New Issue
Block a user