Mirror of https://github.com/ml-explore/mlx.git (synced 2025-08-21 11:38:06 +08:00)
[CUDA] Optimize set_mm_device_pointers for small ndim (#2473)
commit aa7b47481a (parent 56be773610)
@@ -10,7 +10,34 @@ namespace mlx::core::cu {
 
 namespace cg = cooperative_groups;
 
-__global__ void set_mm_device_pointers(
+template <int NDIM>
+__global__ void set_mm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_mm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
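Note: the new `_nd` kernel relies on `elem_to_loc_nd<NDIM>` to turn a flat batch index into per-operand element offsets. That device helper's definition is not part of this diff; the following host-side C++ sketch only illustrates the computation it is assumed to perform (a shape/stride walk from the innermost dimension outward), with a hypothetical name and signature:

#include <cstdint>
#include <utility>

// Hypothetical host-side analogue of elem_to_loc_nd<NDIM>: peel batch
// coordinates off the flat index starting from the innermost dimension and
// accumulate each coordinate against both operands' batch strides.
template <int NDIM>
std::pair<int64_t, int64_t> elem_to_loc_nd_sketch(
    uint32_t index,
    const int32_t* shape,
    const int64_t* a_strides,
    const int64_t* b_strides) {
  int64_t a_offset = 0;
  int64_t b_offset = 0;
  for (int i = NDIM - 1; i >= 0; --i) {
    int32_t coord = index % shape[i]; // coordinate along dimension i
    index /= shape[i];
    a_offset += int64_t(coord) * a_strides[i];
    b_offset += int64_t(coord) * b_strides[i];
  }
  return {a_offset, b_offset};
}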
@@ -38,7 +65,38 @@ __global__ void set_mm_device_pointers(
       out_start + item_size * index * batch_stride;
 }
 
-__global__ void set_addmm_device_pointers(
+template <int NDIM>
+__global__ void set_addmm_device_pointers_nd(
+    int8_t** pointers,
+    int8_t* a_start,
+    int8_t* b_start,
+    int8_t* c_start,
+    int8_t* out_start,
+    int item_size,
+    const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
+    const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
+    int64_t batch_stride,
+    int batch_count) {
+  auto index = cg::this_grid().thread_rank();
+  if (index >= batch_count) {
+    return;
+  }
+  auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
+      index,
+      batch_shape.data(),
+      a_batch_strides.data(),
+      b_batch_strides.data(),
+      c_batch_strides.data());
+  pointers[index] = a_start + item_size * a_offset;
+  pointers[index + batch_count] = b_start + item_size * b_offset;
+  pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
+  pointers[index + 3 * batch_count] =
+      out_start + item_size * index * batch_stride;
+}
+
+__global__ void set_addmm_device_pointers_g(
     int8_t** pointers,
     int8_t* a_start,
     int8_t* b_start,
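Both kernel families fill one flat pointer table partitioned into `batch_count`-sized segments: A pointers first, then B, then (for addmm) C, then the densely laid-out outputs. This is the layout batched GEMM backends consume as arrays of device pointers. A hedged host-side mirror of that layout, where `fill_pointer_table` and its precomputed offset arrays are illustrative names rather than mlx API:

#include <cstdint>

// Hypothetical host-side mirror of set_addmm_device_pointers_*: segment k of
// the flat table holds batch_count pointers for operand k. Inputs take
// per-batch element offsets (from a shape/stride walk); outputs are always
// dense, one matrix every batch_stride elements.
void fill_pointer_table(
    int8_t** ptrs, // 4 * batch_count entries for addmm
    int8_t* a, const int64_t* a_off,
    int8_t* b, const int64_t* b_off,
    int8_t* c, const int64_t* c_off,
    int8_t* out,
    int item_size, int64_t batch_stride, int batch_count) {
  for (int i = 0; i < batch_count; ++i) {
    ptrs[i] = a + item_size * a_off[i];                      // A operands
    ptrs[i + batch_count] = b + item_size * b_off[i];        // B operands
    ptrs[i + 2 * batch_count] = c + item_size * c_off[i];    // C operands
    ptrs[i + 3 * batch_count] =
        out + item_size * int64_t(i) * batch_stride;         // outputs
  }
}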
@@ -89,37 +147,62 @@ void Matmul::run_batched(
     const mlx::core::Shape& batch_shape,
     const mlx::core::Strides& a_batch_strides,
     const mlx::core::Strides& b_batch_strides) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(out_desc_, batch_count);
 
   // Launch kernel to set device offsets
   auto pointers = array(
-      allocator::malloc(batch_count * sizeof(uint64_t) * 3),
-      {static_cast<int>(batch_count * 3)},
+      allocator::malloc(batch_count * sizeof(void*) * 3),
+      {batch_count * 3},
       uint64);
 
   encoder.add_temporary(pointers);
-  int block_size = 512;
   encoder.set_output_array(pointers);
 
-  encoder.add_kernel_node(
-      cu::set_mm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+      encoder.add_kernel_node(
+          cu::set_mm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
+          0,
+          pointers.data<int8_t*>(),
+          a.data<int8_t>(),
+          b.data<int8_t>(),
+          out.data<int8_t>(),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_mm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
+        const_param(batch_shape),
+        const_param(a_batch_strides),
+        const_param(b_batch_strides),
+        batch_stride,
+        ndim,
+        batch_count);
+  }
 
   // Run matmul
   encoder.set_input_array(pointers);
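`run_batched` now routes small ranks through `dispatch_1_2_3`, which lifts the runtime `ndim` into a compile-time constant so the `<NDIM>`-templated kernel can be instantiated. The helper itself is defined elsewhere in the mlx tree; a minimal sketch of the pattern, assuming it passes an `std::integral_constant`-style value to the functor (consistent with the `ndim_constant()` calls in the hunk above):

#include <stdexcept>
#include <type_traits>

// Hypothetical stand-in for mlx's dispatch_1_2_3: map a runtime ndim in
// {1, 2, 3} onto a compile-time integral_constant so the callee can
// instantiate templates such as set_mm_device_pointers_nd<NDIM>.
template <typename F>
void dispatch_1_2_3(int ndim, F&& f) {
  switch (ndim) {
    case 1: f(std::integral_constant<int, 1>{}); break;
    case 2: f(std::integral_constant<int, 2>{}); break;
    case 3: f(std::integral_constant<int, 3>{}); break;
    default:
      throw std::invalid_argument("dispatch_1_2_3: ndim must be 1, 2, or 3");
  }
}

// Usage mirrors the diff; ndim_constant() yields the constant as an int:
// dispatch_1_2_3(ndim, [&](auto ndim_constant) {
//   launch(set_mm_device_pointers_nd<ndim_constant()>, ...);
// });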
@@ -150,7 +233,7 @@ void Matmul::run_batched(
     const mlx::core::Strides& c_batch_strides,
     float alpha,
     float beta) {
-  auto batch_count = out.size() / (M_ * N_);
+  int batch_count = out.size() / (M_ * N_);
   set_pointer_mode(a_desc_, batch_count);
   set_pointer_mode(b_desc_, batch_count);
   set_pointer_mode(c_desc_, batch_count);
@@ -159,30 +242,58 @@ void Matmul::run_batched(
   // Launch kernel to set device offsets
   auto pointers = array(
       allocator::malloc(batch_count * sizeof(uint64_t) * 4),
-      {static_cast<int>(batch_count * 4)},
+      {batch_count * 4},
       uint64);
 
   encoder.add_temporary(pointers);
-  int block_size = 512;
   encoder.set_output_array(pointers);
-  encoder.add_kernel_node(
-      cu::set_addmm_device_pointers,
-      cuda::ceil_div(pointers.size(), block_size),
-      block_size,
-      0,
-      pointers.data<int8_t*>(),
-      a.data<int8_t>(),
-      b.data<int8_t>(),
-      c.data<int8_t>(),
-      out.data<int8_t>(),
-      static_cast<int>(out.dtype().size()),
-      const_param(batch_shape),
-      const_param(a_batch_strides),
-      const_param(b_batch_strides),
-      const_param(c_batch_strides),
-      static_cast<int64_t>(M_) * N_,
-      static_cast<int>(batch_shape.size()),
-      batch_count);
+
+  int block_dims = std::min(batch_count, 256);
+  int num_blocks = cuda::ceil_div(batch_count, block_dims);
+  int64_t batch_stride = M_ * N_;
+  int item_size = out.itemsize();
+
+  int ndim = batch_shape.size();
+  if (ndim <= 3) {
+    dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+      encoder.add_kernel_node(
+          cu::set_addmm_device_pointers_nd<ndim_constant()>,
+          num_blocks,
+          block_dims,
+          0,
+          pointers.data<int8_t*>(),
+          a.data<int8_t>(),
+          b.data<int8_t>(),
+          c.data<int8_t>(),
+          out.data<int8_t>(),
+          item_size,
+          const_param<ndim_constant()>(batch_shape),
+          const_param<ndim_constant()>(a_batch_strides),
+          const_param<ndim_constant()>(b_batch_strides),
+          const_param<ndim_constant()>(c_batch_strides),
+          batch_stride,
+          batch_count);
+    });
+  } else {
+    encoder.add_kernel_node(
+        cu::set_addmm_device_pointers_g,
+        num_blocks,
+        block_dims,
+        0,
+        pointers.data<int8_t*>(),
+        a.data<int8_t>(),
+        b.data<int8_t>(),
+        c.data<int8_t>(),
+        out.data<int8_t>(),
+        item_size,
+        const_param(batch_shape),
+        const_param(a_batch_strides),
+        const_param(b_batch_strides),
+        const_param(c_batch_strides),
+        batch_stride,
+        ndim,
+        batch_count);
+  }
 
   // Run matmul
   encoder.set_input_array(pointers);
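Both call sites also replace the old fixed 512-thread launch over `pointers.size()` with a grid sized to `batch_count` directly, so one thread sets one batch entry's pointers. A small standalone example of the new launch arithmetic; `ceil_div` here is a local stand-in for `cuda::ceil_div`:

#include <algorithm>
#include <cstdio>

// Round-up integer division, as cuda::ceil_div does for integral arguments.
int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // e.g. a (7, 3) batch of matmuls -> batch_count = 21
  int batch_count = 21;
  int block_dims = std::min(batch_count, 256);        // 21 threads
  int num_blocks = ceil_div(batch_count, block_dims); // 1 block
  std::printf("%d block(s) x %d thread(s)\n", num_blocks, block_dims);
  // The old code launched ceil_div(3 * batch_count, 512) blocks of 512
  // threads, i.e. a full 512-thread block even for tiny batches.
  return 0;
}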