Compare commits

...

2 Commits

Author SHA1 Message Date
Cheng
7fde1b6a1e Fix logsumexp/softmax not fused for some cases (#2474) 2025-08-08 14:07:17 -07:00
Cheng
aa7b47481a [CUDA] Optimize set_mm_device_pointers for small ndim (#2473) 2025-08-08 15:23:30 +09:00
2 changed files with 180 additions and 48 deletions

View File

@@ -10,7 +10,34 @@ namespace mlx::core::cu {
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
__global__ void set_mm_device_pointers( template <int NDIM>
__global__ void set_mm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_mm_device_pointers_g(
int8_t** pointers, int8_t** pointers,
int8_t* a_start, int8_t* a_start,
int8_t* b_start, int8_t* b_start,
@@ -38,7 +65,38 @@ __global__ void set_mm_device_pointers(
out_start + item_size * index * batch_stride; out_start + item_size * index * batch_stride;
} }
__global__ void set_addmm_device_pointers( template <int NDIM>
__global__ void set_addmm_device_pointers_nd(
int8_t** pointers,
int8_t* a_start,
int8_t* b_start,
int8_t* c_start,
int8_t* out_start,
int item_size,
const __grid_constant__ cuda::std::array<int32_t, NDIM> batch_shape,
const __grid_constant__ cuda::std::array<int64_t, NDIM> a_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> b_batch_strides,
const __grid_constant__ cuda::std::array<int64_t, NDIM> c_batch_strides,
int64_t batch_stride,
int batch_count) {
auto index = cg::this_grid().thread_rank();
if (index >= batch_count) {
return;
}
auto [a_offset, b_offset, c_offset] = elem_to_loc_nd<NDIM>(
index,
batch_shape.data(),
a_batch_strides.data(),
b_batch_strides.data(),
c_batch_strides.data());
pointers[index] = a_start + item_size * a_offset;
pointers[index + batch_count] = b_start + item_size * b_offset;
pointers[index + 2 * batch_count] = c_start + item_size * c_offset;
pointers[index + 3 * batch_count] =
out_start + item_size * index * batch_stride;
}
__global__ void set_addmm_device_pointers_g(
int8_t** pointers, int8_t** pointers,
int8_t* a_start, int8_t* a_start,
int8_t* b_start, int8_t* b_start,
@@ -89,37 +147,62 @@ void Matmul::run_batched(
const mlx::core::Shape& batch_shape, const mlx::core::Shape& batch_shape,
const mlx::core::Strides& a_batch_strides, const mlx::core::Strides& a_batch_strides,
const mlx::core::Strides& b_batch_strides) { const mlx::core::Strides& b_batch_strides) {
auto batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count); set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count); set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(out_desc_, batch_count); set_pointer_mode(out_desc_, batch_count);
// Launch kernel to set device offsets // Launch kernel to set device offsets
auto pointers = array( auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 3), allocator::malloc(batch_count * sizeof(void*) * 3),
{static_cast<int>(batch_count * 3)}, {batch_count * 3},
uint64); uint64);
encoder.add_temporary(pointers); encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers); encoder.set_output_array(pointers);
encoder.add_kernel_node( int block_dims = std::min(batch_count, 256);
cu::set_mm_device_pointers, int num_blocks = cuda::ceil_div(batch_count, block_dims);
cuda::ceil_div(pointers.size(), block_size), int64_t batch_stride = M_ * N_;
block_size, int item_size = out.itemsize();
0,
pointers.data<int8_t*>(), int ndim = batch_shape.size();
a.data<int8_t>(), if (ndim <= 3) {
b.data<int8_t>(), dispatch_1_2_3(ndim, [&](auto ndim_constant) {
out.data<int8_t>(), encoder.add_kernel_node(
static_cast<int>(out.dtype().size()), cu::set_mm_device_pointers_nd<ndim_constant()>,
const_param(batch_shape), num_blocks,
const_param(a_batch_strides), block_dims,
const_param(b_batch_strides), 0,
static_cast<int64_t>(M_) * N_, pointers.data<int8_t*>(),
static_cast<int>(batch_shape.size()), a.data<int8_t>(),
batch_count); b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_mm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul // Run matmul
encoder.set_input_array(pointers); encoder.set_input_array(pointers);
@@ -150,7 +233,7 @@ void Matmul::run_batched(
const mlx::core::Strides& c_batch_strides, const mlx::core::Strides& c_batch_strides,
float alpha, float alpha,
float beta) { float beta) {
auto batch_count = out.size() / (M_ * N_); int batch_count = out.size() / (M_ * N_);
set_pointer_mode(a_desc_, batch_count); set_pointer_mode(a_desc_, batch_count);
set_pointer_mode(b_desc_, batch_count); set_pointer_mode(b_desc_, batch_count);
set_pointer_mode(c_desc_, batch_count); set_pointer_mode(c_desc_, batch_count);
@@ -159,30 +242,58 @@ void Matmul::run_batched(
// Launch kernel to set device offsets // Launch kernel to set device offsets
auto pointers = array( auto pointers = array(
allocator::malloc(batch_count * sizeof(uint64_t) * 4), allocator::malloc(batch_count * sizeof(uint64_t) * 4),
{static_cast<int>(batch_count * 4)}, {batch_count * 4},
uint64); uint64);
encoder.add_temporary(pointers); encoder.add_temporary(pointers);
int block_size = 512;
encoder.set_output_array(pointers); encoder.set_output_array(pointers);
encoder.add_kernel_node(
cu::set_addmm_device_pointers, int block_dims = std::min(batch_count, 256);
cuda::ceil_div(pointers.size(), block_size), int num_blocks = cuda::ceil_div(batch_count, block_dims);
block_size, int64_t batch_stride = M_ * N_;
0, int item_size = out.itemsize();
pointers.data<int8_t*>(),
a.data<int8_t>(), int ndim = batch_shape.size();
b.data<int8_t>(), if (ndim <= 3) {
c.data<int8_t>(), dispatch_1_2_3(ndim, [&](auto ndim_constant) {
out.data<int8_t>(), encoder.add_kernel_node(
static_cast<int>(out.dtype().size()), cu::set_addmm_device_pointers_nd<ndim_constant()>,
const_param(batch_shape), num_blocks,
const_param(a_batch_strides), block_dims,
const_param(b_batch_strides), 0,
const_param(c_batch_strides), pointers.data<int8_t*>(),
static_cast<int64_t>(M_) * N_, a.data<int8_t>(),
static_cast<int>(batch_shape.size()), b.data<int8_t>(),
batch_count); c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param<ndim_constant()>(batch_shape),
const_param<ndim_constant()>(a_batch_strides),
const_param<ndim_constant()>(b_batch_strides),
const_param<ndim_constant()>(c_batch_strides),
batch_stride,
batch_count);
});
} else {
encoder.add_kernel_node(
cu::set_addmm_device_pointers_g,
num_blocks,
block_dims,
0,
pointers.data<int8_t*>(),
a.data<int8_t>(),
b.data<int8_t>(),
c.data<int8_t>(),
out.data<int8_t>(),
item_size,
const_param(batch_shape),
const_param(a_batch_strides),
const_param(b_batch_strides),
const_param(c_batch_strides),
batch_stride,
ndim,
batch_count);
}
// Run matmul // Run matmul
encoder.set_input_array(pointers); encoder.set_input_array(pointers);

View File

@@ -2381,9 +2381,20 @@ array logsumexp(
throw std::invalid_argument( throw std::invalid_argument(
"[logsumexp] Received non-empty axes for array with 0 dimensions."); "[logsumexp] Received non-empty axes for array with 0 dimensions.");
} }
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating); bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && axes.size() == 1 && if (!is_complex && reduce_last_dim) {
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype()); auto dtype = at_least_float(a.dtype());
auto out_shape = a.shape(); auto out_shape = a.shape();
out_shape.back() = 1; out_shape.back() = 1;
@@ -3403,10 +3414,20 @@ array softmax(
throw std::invalid_argument( throw std::invalid_argument(
"[softmax] Received non-empty axes for array with 0 dimensions."); "[softmax] Received non-empty axes for array with 0 dimensions.");
} }
bool reduce_last_dim =
!axes.empty() && (axes.back() == a.ndim() - 1 || axes.back() == -1);
if (reduce_last_dim) {
// For more than 2 axes check if axes is [0, 1, ..., NDIM - 1] and shape
// is [1, 1, ..., N].
for (int i = axes.size() - 2; i >= 0; --i) {
if ((axes[i] + 1 != axes[i + 1]) || (a.shape(axes[i]) != 1)) {
reduce_last_dim = false;
break;
}
}
}
bool is_complex = issubdtype(a.dtype(), complexfloating); bool is_complex = issubdtype(a.dtype(), complexfloating);
if (!is_complex && axes.size() == 1 && if (!is_complex && reduce_last_dim) {
(a.ndim() == axes[0] + 1 || axes[0] == -1)) {
auto dtype = at_least_float(a.dtype()); auto dtype = at_least_float(a.dtype());
return array( return array(
a.shape(), a.shape(),