mirror of https://github.com/ml-explore/mlx.git
cuda batched mm
@@ -25,7 +25,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
@@ -3,12 +3,14 @@
 #include "mlx/backend/common/matmul.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/gemv.h"
+#include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/dtype_utils.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"

 #include <fmt/format.h>
+#include <cooperative_groups.h>
 #include <nvtx3/nvtx3.hpp>

 #include <numeric>
@@ -197,6 +199,28 @@ class MatMul {
         encoder.stream()));
   }

+  void use_batch_pointer_mode(int batch_count) {
+    auto set_pointer_mode = [&batch_count](auto desc) {
+      auto batch_mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
+      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+          desc,
+          CUBLASLT_MATRIX_LAYOUT_BATCH_MODE,
+          &batch_mode,
+          sizeof(batch_mode)));
+      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+          desc,
+          CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+          &batch_count,
+          sizeof(int32_t)));
+    };
+    set_pointer_mode(a_desc_);
+    set_pointer_mode(b_desc_);
+    if (c_desc_) {
+      set_pointer_mode(c_desc_);
+    }
+    set_pointer_mode(out_desc_);
+  }
+
  private:
   cublasComputeType_t dtype_to_compute_type(Dtype dtype) {
     switch (dtype) {
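The helper added above switches each cuBLASLt matrix layout from the default strided batching to pointer-array batching (an attribute available in recent cuBLASLt versions). The following is a minimal standalone sketch, not MLX code, of how that mode is consumed end to end: it assumes the handle, matmul descriptor, and layouts already exist, that the scale type is float32, and that d_a_ptrs / d_b_ptrs / d_out_ptrs are device arrays of batch_count pointers, analogous to the buffer this commit fills with cu::set_mm_device_pointers. Error checks and algorithm/workspace selection are omitted.

#include <cublasLt.h>

void batched_mm_pointer_mode_sketch(
    cublasLtHandle_t lt,
    cublasLtMatmulDesc_t matmul_desc,
    cublasLtMatrixLayout_t a_desc,
    cublasLtMatrixLayout_t b_desc,
    cublasLtMatrixLayout_t out_desc,
    const void* d_a_ptrs, // device array: batch_count pointers to A matrices
    const void* d_b_ptrs, // device array: batch_count pointers to B matrices
    void* d_out_ptrs,     // device array: batch_count pointers to outputs
    int32_t batch_count,
    cudaStream_t stream) {
  // Switch every layout from strided batching to pointer-array batching.
  auto set_pointer_mode = [&](cublasLtMatrixLayout_t desc) {
    auto mode = CUBLASLT_BATCH_MODE_POINTER_ARRAY;
    cublasLtMatrixLayoutSetAttribute(
        desc, CUBLASLT_MATRIX_LAYOUT_BATCH_MODE, &mode, sizeof(mode));
    cublasLtMatrixLayoutSetAttribute(
        desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count));
  };
  set_pointer_mode(a_desc);
  set_pointer_mode(b_desc);
  set_pointer_mode(out_desc);

  // In pointer-array mode, the A/B/C/D arguments of cublasLtMatmul are the
  // device arrays of per-batch pointers, not the matrices themselves.
  float alpha = 1.0f, beta = 0.0f;
  cublasLtMatmul(
      lt, matmul_desc, &alpha,
      d_a_ptrs, a_desc,
      d_b_ptrs, b_desc,
      &beta,
      d_out_ptrs, out_desc, // C (ignored when beta == 0)
      d_out_ptrs, out_desc, // D (output)
      /*algo=*/nullptr,
      /*workspace=*/nullptr, /*workspaceSize=*/0,
      stream);
}

This mirrors the call path in the hunk below, where MatMul::run is handed the reinterpreted pointer blocks instead of the raw matrix buffers.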
@@ -377,27 +401,54 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       a_batch_strides.back(),
       b_batch_strides.back());

-  encoder.set_input_array(a);
-  encoder.set_input_array(b);
-  encoder.set_output_array(out);
-  auto nbatch = batch_count / batch_shape.back();
-  if (nbatch == 1) {
+  if ((batch_count / batch_shape.back()) == 1) {
+    encoder.set_input_array(a);
+    encoder.set_input_array(b);
+    encoder.set_output_array(out);
     matmul.run(encoder, out.data<int8_t>(), a.data<int8_t>(), b.data<int8_t>());
     return;
   }

-  ContiguousIterator a_it(batch_shape, a_batch_strides, batch_shape.size() - 1);
-  ContiguousIterator b_it(batch_shape, b_batch_strides, batch_shape.size() - 1);
-  auto concurrent = encoder.concurrent_context();
-  for (size_t i = 0; i < nbatch; ++i) {
-    matmul.run(
-        encoder,
-        out.data<int8_t>() + out.itemsize() * i * batch_shape.back() * M * N,
-        a.data<int8_t>() + a.itemsize() * a_it.loc,
-        b.data<int8_t>() + b.itemsize() * b_it.loc);
-    a_it.step();
-    b_it.step();
-  }
+  // If we get here use pointer mode
+  matmul.use_batch_pointer_mode(batch_count);
+
+  // Launch kernel to set device offsets
+  auto pointers = array(
+      allocator::malloc(batch_count * sizeof(uint64_t) * 3), {static_cast<int>(batch_count * 3)}, uint64);
+  encoder.add_temporary(pointers);
+  int block_size = 512;
+  encoder.set_output_array(pointers);
+
+  encoder.add_kernel_node(
+      cu::set_mm_device_pointers,
+      cuda::ceil_div(pointers.size(), block_size),
+      block_size,
+      pointers.data<int8_t*>(),
+      a.data<int8_t>(),
+      b.data<int8_t>(),
+      out.data<int8_t>(),
+      static_cast<int>(out.dtype().size()),
+      const_param(batch_shape),
+      const_param(a_batch_strides),
+      const_param(b_batch_strides),
+      static_cast<int64_t>(M) * N,
+      static_cast<int>(batch_shape.size()),
+      batch_count);
+
+  // Run matmul
+  encoder.set_input_array(pointers);
+  encoder.set_input_array(a);
+  encoder.set_input_array(b);
+  encoder.set_output_array(out);
+
+  auto a_pointers = pointers.data<int8_t*>();
+  auto b_pointers = a_pointers + batch_count;
+  auto out_pointers = b_pointers + batch_count;
+  matmul.run(
+      encoder,
+      reinterpret_cast<int8_t*>(out_pointers),
+      reinterpret_cast<int8_t*>(a_pointers),
+      reinterpret_cast<int8_t*>(b_pointers));
 }

 void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
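The body of cu::set_mm_device_pointers is not part of the hunks shown here. The following is a hypothetical CUDA sketch of what such a kernel does, given the arguments passed at the launch site above: one thread per batch entry decomposes the flat batch index against batch_shape, accumulates the A and B element offsets from the (possibly broadcast, zero-stride) batch strides, and writes the A, B, and output pointers into three consecutive blocks of the pointers buffer. The name suffix _sketch and the raw-pointer parameters are assumptions; the real kernel presumably receives the packed const_param structures shown at the call site.

#include <stdint.h>

#include <cuda_runtime.h>

// Fills `pointers` with [A_0 .. A_{n-1} | B_0 .. B_{n-1} | out_0 .. out_{n-1}],
// the layout the pointer-mode matmul above indexes with a_pointers, b_pointers,
// and out_pointers.
__global__ void set_mm_device_pointers_sketch(
    int8_t** pointers, // 3 * batch_count entries
    int8_t* a,
    int8_t* b,
    int8_t* out,
    int item_size, // bytes per element, e.g. out.dtype().size()
    const int32_t* batch_shape, // broadcast batch dimensions
    const int64_t* a_batch_strides, // element strides, 0 where broadcast
    const int64_t* b_batch_strides,
    int64_t matrix_size, // M * N elements per output matrix
    int ndim,
    int batch_count) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= batch_count) {
    return;
  }
  // Decompose the flat batch index into per-dimension coordinates and
  // accumulate the element offsets of A and B for this batch entry.
  int64_t a_loc = 0;
  int64_t b_loc = 0;
  int rem = i;
  for (int d = ndim - 1; d >= 0; --d) {
    int coord = rem % batch_shape[d];
    rem /= batch_shape[d];
    a_loc += coord * a_batch_strides[d];
    b_loc += coord * b_batch_strides[d];
  }
  // The output is contiguous, so its offset is simply i * M * N elements.
  pointers[i] = a + a_loc * item_size;
  pointers[i + batch_count] = b + b_loc * item_size;
  pointers[i + 2 * batch_count] =
      out + static_cast<int64_t>(i) * matrix_size * item_size;
}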
@@ -487,6 +538,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(c);
   encoder.set_output_array(out);

+  // TODO use pointer mode here as well
   auto nbatch = batch_count / batch_shape.back();
   if (nbatch == 1) {
     matmul.run(
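The TODO added here notes that AddMM still uses the strided loop. One hypothetical way to extend the pointer scheme to it, not part of this commit: grow the temporary buffer to four blocks of batch_count pointers so the C / bias matrices get their own block, which use_batch_pointer_mode already anticipates by configuring c_desc_ when it is present. The struct and helper below are illustrative only.

#include <stdint.h>

// Hypothetical pointer-buffer layout for an AddMM pointer-array path.
struct AddMMBatchPointers {
  int8_t** a;   // block 0: A matrices
  int8_t** b;   // block 1: B matrices
  int8_t** c;   // block 2: C / bias matrices
  int8_t** out; // block 3: outputs
};

inline AddMMBatchPointers split_addmm_pointers(int8_t** base, int batch_count) {
  return {base, base + batch_count, base + 2 * batch_count, base + 3 * batch_count};
}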