Add dynamic shared memory

Angelos Katharopoulos
2025-07-22 23:36:53 -07:00
parent 1523b803f3
commit c456d59e9f
27 changed files with 119 additions and 15 deletions
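
This change threads a dynamic shared memory size through the CUDA backend's command encoder: every CommandEncoder::add_kernel_node overload gains a uint32_t smem_bytes parameter, which is written into the kernel node's sharedMemBytes field, and all existing call sites pass 0. The steel tiles header also gains a RegisterTile container of Tile16x16 tiles, and its utils.cuh include moves to the mlx/backend/cuda/steel/ path.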

View File

@@ -171,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
       kernel,
       num_blocks,
       block_dim(),
+      0,
       in.data<T>(),
       out.data<uint32_t>(),
       out.size(),

View File

@@ -222,6 +222,7 @@ void binary_op_gpu_inplace(
           dims_constant()>,
       num_blocks,
       block_dims,
+      0,
       a.data<InType>(),
       b.data<InType>(),
       out.data<OutType>(),
@@ -236,6 +237,7 @@ void binary_op_gpu_inplace(
       cu::binary_g<Op, InType, OutType, IdxT>,
       num_blocks,
       block_dims,
+      0,
       a.data<InType>(),
       b.data<InType>(),
       out.data<OutType>(),
@@ -264,6 +266,7 @@ void binary_op_gpu_inplace(
       kernel,
       num_blocks,
       block_dims,
+      0,
       a.data<InType>(),
       b.data<InType>(),
       out.data<OutType>(),

View File

@@ -238,6 +238,7 @@ void binary_two_op_gpu_inplace(
           dims_constant()>,
       num_blocks,
       block_dims,
+      0,
       a.data<InType>(),
       b.data<InType>(),
       out_a.data<OutType>(),
@@ -254,6 +255,7 @@ void binary_two_op_gpu_inplace(
       cu::binary_two_g<Op, InType, OutType, IdxT>,
       num_blocks,
       block_dims,
+      0,
       a.data<InType>(),
       b.data<InType>(),
       out_a.data<OutType>(),
@@ -287,6 +289,7 @@ void binary_two_op_gpu_inplace(
       kernel,
       num_blocks,
       block_dims,
+      0,
       a.data<InType>(),
       b.data<InType>(),
       out_a.data<OutType>(),

View File

@@ -295,7 +295,7 @@ void Compiled::eval_gpu(
   auto kernel = mod.get_kernel(kernel_name);
   auto [num_blocks, block_dims] =
       get_launch_args(outputs[0], large, work_per_thread);
-  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

 } // namespace mlx::core

View File

@@ -76,6 +76,7 @@ void copy_contiguous(
       kernel,
       num_blocks,
       block_dims,
+      0,
       in.data<InType>() + in_offset,
       out.data<OutType>() + out_offset,
       out.data_size());

View File

@@ -77,6 +77,7 @@ void copy_general(
       cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>,
       num_blocks,
       block_dims,
+      0,
       in_ptr,
       out_ptr,
       data_size,
@@ -91,6 +92,7 @@ void copy_general(
       cu::copy_gg<InType, OutType, IdxT>,
       num_blocks,
       block_dims,
+      0,
       in_ptr,
       out_ptr,
       data_size,

View File

@@ -83,6 +83,7 @@ void copy_general_dynamic(
           dims_constant()>,
       num_blocks,
       block_dims,
+      0,
       in_ptr,
       out_ptr,
       out.size(),
@@ -98,6 +99,7 @@ void copy_general_dynamic(
       cu::copy_gg_dynamic<InType, OutType, IdxT>,
       num_blocks,
       block_dims,
+      0,
       in_ptr,
       out_ptr,
       out.size(),

View File

@@ -68,6 +68,7 @@ void copy_general_input(
       cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>,
       num_blocks,
       block_dims,
+      0,
       in_ptr,
       out_ptr,
       out.size(),
@@ -80,6 +81,7 @@ void copy_general_input(
       cu::copy_g<InType, OutType, IdxT>,
       num_blocks,
       block_dims,
+      0,
       in_ptr,
       out_ptr,
       out.size(),

View File

@@ -224,12 +224,14 @@ void CommandEncoder::add_kernel_node(
     void* func,
     dim3 grid_dim,
     dim3 block_dim,
+    uint32_t smem_bytes,
     void** params) {
   cudaKernelNodeParams kernel_params = {0};
   kernel_params.func = func;
   kernel_params.gridDim = grid_dim;
   kernel_params.blockDim = block_dim;
   kernel_params.kernelParams = params;
+  kernel_params.sharedMemBytes = smem_bytes;
   add_kernel_node(kernel_params);
 }
@@ -237,6 +239,7 @@ void CommandEncoder::add_kernel_node(
     CUfunction func,
     dim3 grid_dim,
     dim3 block_dim,
+    uint32_t smem_bytes,
     void** params) {
   CUDA_KERNEL_NODE_PARAMS kernel_params = {0};
   kernel_params.func = func;
@@ -247,6 +250,7 @@ void CommandEncoder::add_kernel_node(
   kernel_params.blockDimY = block_dim.y;
   kernel_params.blockDimZ = block_dim.z;
   kernel_params.kernelParams = params;
+  kernel_params.sharedMemBytes = smem_bytes;
   add_kernel_node(kernel_params);
 }

View File

@@ -47,25 +47,34 @@ class CommandEncoder {
   void set_output_array(const array& arr);

   template <typename F, typename... Params>
-  void
-  add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) {
+  void add_kernel_node(
+      F* func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      uint32_t smem_bytes,
+      Params&&... params) {
     constexpr size_t num = sizeof...(Params);
     void* ptrs[num];
     size_t i = 0;
     ([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
         std::forward<Params>(params)),
     ...);
-    add_kernel_node((void*)func, grid_dim, block_dim, ptrs);
+    add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs);
   }

   void add_kernel_node(
       CUfunction func,
       dim3 grid_dim,
       dim3 block_dim,
+      uint32_t smem_bytes,
       void** params);
-  void
-  add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params);
+  void add_kernel_node(
+      void* func,
+      dim3 grid_dim,
+      dim3 block_dim,
+      uint32_t smem_bytes,
+      void** params);

   // Low-level graph helpers.
   void add_kernel_node(const cudaKernelNodeParams& params);
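
With the new signature, a call site that needs dynamic shared memory passes a nonzero byte count and the encoder forwards it to the graph node untouched. Below is a minimal sketch of such a kernel and its encoding; the kernel, names, and sizes are hypothetical illustrations, not part of this commit:

#include <cfloat>

// Hypothetical kernel: its scratch buffer is sized at launch time by the
// smem_bytes argument below (blockDim.x floats).
__global__ void row_max(const float* in, float* out, int row_size) {
  extern __shared__ float scratch[];
  const float* row = in + blockIdx.x * row_size;
  float v = -FLT_MAX;
  for (int i = threadIdx.x; i < row_size; i += blockDim.x) {
    v = fmaxf(v, row[i]);
  }
  scratch[threadIdx.x] = v;
  __syncthreads();
  // Tree reduction over the dynamically sized buffer.
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) {
      scratch[threadIdx.x] = fmaxf(scratch[threadIdx.x], scratch[threadIdx.x + s]);
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    out[blockIdx.x] = scratch[0];
  }
}

// Encoding it as a graph node; everything after smem_bytes is packed into
// the kernelParams array by the variadic overload above:
//   encoder.add_kernel_node(
//       row_max,
//       n_rows,                   // grid_dim
//       block,                    // block_dim
//       block.x * sizeof(float),  // smem_bytes -> sharedMemBytes
//       in_ptr, out_ptr, row_size);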

View File

@@ -108,6 +108,7 @@ void Matmul::run_batched(
       cu::set_mm_device_pointers,
       cuda::ceil_div(pointers.size(), block_size),
       block_size,
+      0,
       pointers.data<int8_t*>(),
       a.data<int8_t>(),
       b.data<int8_t>(),
@@ -168,6 +169,7 @@ void Matmul::run_batched(
       cu::set_addmm_device_pointers,
       cuda::ceil_div(pointers.size(), block_size),
       block_size,
+      0,
       pointers.data<int8_t*>(),
       a.data<int8_t>(),
       b.data<int8_t>(),

View File

@@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto kernel = mod.get_kernel(kernel_name);
   auto [num_blocks, block_dims] = get_launch_args(out, large);
-  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
   auto [num_blocks, block_dims] = get_launch_args(upd, large);
-  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

 void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
   auto [num_blocks, block_dims] = get_launch_args(idx, large);
-  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

 void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_output_array(out);
   auto kernel = mod.get_kernel(kernel_name);
   auto [num_blocks, block_dims] = get_launch_args(idx, large);
-  encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args());
+  encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args());
 }

 } // namespace mlx::core

View File

@@ -279,6 +279,7 @@ void LayerNorm::eval_gpu(
       kernel,
       n_rows,
       block_dim(),
+      0,
       x.data<DataType>(),
       w.data<DataType>(),
       b.data<DataType>(),
@@ -391,6 +392,7 @@ void LayerNormVJP::eval_gpu(
       kernel,
       n_rows,
       block_dim(),
+      0,
       x.data<DataType>(),
       w.data<DataType>(),
       g.data<DataType>(),

View File

@@ -150,6 +150,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
       kernel,
       n_rows,
       block_dim(),
+      0,
       in.data<DataType>(),
       out.data<DataType>(),
       axis_size);

View File

@@ -261,6 +261,7 @@ void affine_quantize(
       kernel,
       num_blocks,
       block_dims,
+      0,
       w.data<T>(),
       wq.data<uint8_t>(),
       scales.data<T>(),
@@ -316,6 +317,7 @@ void affine_dequantize(
       kernel,
       num_blocks,
       block_dims,
+      0,
       wq.data<uint8_t>(),
       scales.data<T>(),
       biases.data<T>(),

View File

@@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
       cu::rbitsc,
       grid,
       block,
+      0,
       keys.data<uint32_t>(),
       out.data<uint8_t>(),
       grid_dims,
@@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector<array>& inputs, array& out) {
       cu::rbits,
       grid,
       block,
+      0,
       keys.data<uint32_t>(),
       out.data<uint8_t>(),
       grid_dims,

View File

@@ -120,6 +120,7 @@ void all_reduce(
       kernel,
       blocks,
       threads,
+      0,
       static_cast<T*>(indata),
       intermediate.data<U>(),
       block_step,
@@ -146,6 +147,7 @@ void all_reduce(
       kernel,
       blocks,
       threads,
+      0,
       static_cast<T*>(indata),
       out.data<U>(),
       block_step,

View File

@@ -230,7 +230,7 @@ void col_reduce_looped(
       auto kernel =
           cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
       encoder.add_kernel_node(
-          kernel, grid, blocks, indata, out.data<U>(), args);
+          kernel, grid, blocks, 0, indata, out.data<U>(), args);
     });
   });
 });

View File

@@ -41,7 +41,8 @@ void init_reduce(
     dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
     dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
     grid.x = (grid.x + 1023) / 1024;
-    encoder.add_kernel_node(kernel, grid, block, out.data<U>(), out.size());
+    encoder.add_kernel_node(
+        kernel, grid, block, 0, out.data<U>(), out.size());
   });
 });
 }

View File

@@ -269,7 +269,7 @@ void row_reduce_simple(
     int size = plan.shape.back();
     encoder.add_kernel_node(
-        kernel, grid, block, indata, out.data<U>(), out.size(), size);
+        kernel, grid, block, 0, indata, out.data<U>(), out.size(), size);
   });
 });
 }
@@ -322,7 +322,7 @@ void row_reduce_looped(
     });
     encoder.add_kernel_node(
-        kernel, grid, block, indata, out.data<U>(), out.size(), args);
+        kernel, grid, block, 0, indata, out.data<U>(), out.size(), args);
   });
 });
 }

View File

@@ -222,6 +222,7 @@ void RMSNorm::eval_gpu(
       kernel,
       n_rows,
       block_dim(),
+      0,
       x.data<DataType>(),
       w.data<DataType>(),
       out.data<DataType>(),
@@ -316,6 +317,7 @@ void RMSNormVJP::eval_gpu(
       kernel,
       n_rows,
       block_dim(),
+      0,
       x.data<DataType>(),
       w.data<DataType>(),
       g.data<DataType>(),

View File

@@ -325,6 +325,7 @@ void RoPE::eval_gpu(
       kernel,
       grid,
       block,
+      0,
       (donated ? out : in).data<DataType>(),
       out.data<DataType>(),
       offset.data<int32_t>(),
@@ -341,6 +342,7 @@ void RoPE::eval_gpu(
       kernel,
       grid,
       block,
+      0,
       (donated ? out : in).data<DataType>(),
       out.data<DataType>(),
       offset.data<int32_t>(),
@@ -360,6 +362,7 @@ void RoPE::eval_gpu(
       kernel,
       grid,
       block,
+      0,
       (donated ? out : in).data<DataType>(),
       out.data<DataType>(),
       offset.data<int32_t>(),
@@ -381,6 +384,7 @@ void RoPE::eval_gpu(
       kernel,
       grid,
       block,
+      0,
       (donated ? out : in).data<DataType>(),
       out.data<DataType>(),
       offset.data<int32_t>(),

View File

@@ -414,6 +414,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
       kernel,
       in.data_size() / axis_size,
       block_dim,
+      0,
       in.data<T>(),
       out.data<U>(),
       axis_size);
@@ -443,6 +444,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
       kernel,
       num_blocks,
       block_dim,
+      0,
       in.data<T>(),
       out.data<U>(),
       axis_size,

View File

@@ -151,6 +151,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
       kernel,
       n_rows,
       block_dim(),
+      0,
       in.data<DataType>(),
       out.data<DataType>(),
       axis_size);

View File

@@ -2,7 +2,7 @@

 #pragma once

-#include "mlx/backend/steel/utils.cuh"
+#include "mlx/backend/cuda/steel/utils.cuh"

 namespace mlx::core::cu {
@@ -223,6 +223,57 @@ struct RegisterTile {
   }
 };

+/**
+ * A simple container of multiple Tile16x16.
+ *
+ * Provides utility functions for loading and manipulating collections of basic
+ * tiles.
+ */
+template <typename T, int ROWS_, int COLS_>
+struct RegisterTile {
+  static constexpr int ROWS = ROWS_;
+  static constexpr int COLS = COLS_;
+  static constexpr int TILES_X = COLS / 16;
+  static constexpr int TILES_Y = ROWS / 16;
+
+  Tile16x16<T> data[TILES_X * TILES_Y];
+
+  __device__ inline void fill(T v) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].fill(v);
+      }
+    }
+  }
+
+  template <typename Tile>
+  __device__ inline void
+  load(Tile& tile, uint32_t base_address, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].load(
+            tile.loc(base_address, row + i * 16, col + j * 16));
+      }
+    }
+  }
+
+  template <typename U>
+  __device__ inline void store_global(U* x, int N, int row, int col) {
+    MLX_UNROLL
+    for (int i = 0; i < TILES_Y; i++) {
+      MLX_UNROLL
+      for (int j = 0; j < TILES_X; j++) {
+        data[i * TILES_X + j].store_global(
+            x + (row + i * 16) * N + col + j * 16, N);
+      }
+    }
+  }
+};
+
 template <typename T, int ROWS_, int COLS_>
 struct SharedTile {
   static constexpr int ROWS = ROWS_;
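
The RegisterTile added above is a register-resident accumulator: a TILES_Y by TILES_X grid of Tile16x16 fragments with bulk fill/load/store helpers. A usage sketch, where out, N, row, and col are hypothetical kernel-local names and Tile16x16's layout is as defined elsewhere in this header:

// Sketch only: a 32x32 accumulator made of four 16x16 register tiles.
cu::RegisterTile<float, 32, 32> acc;
acc.fill(0.0f);
// ... accumulate into acc.data[...] ...
// Write the 32x32 block to global memory at (row, col) with row stride N.
acc.store_global(out, N, row, col);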

View File

@@ -130,6 +130,7 @@ void ternary_op_gpu_inplace(
       cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>,
       num_blocks,
       block_dims,
+      0,
       a.data<bool>(),
       b.data<DType>(),
       c.data<DType>(),
@@ -146,6 +147,7 @@ void ternary_op_gpu_inplace(
       cu::ternary_g<Op, DType, IdxT>,
       num_blocks,
       block_dims,
+      0,
       a.data<bool>(),
       b.data<DType>(),
       c.data<DType>(),
@@ -168,6 +170,7 @@ void ternary_op_gpu_inplace(
       cu::ternary_v<Op, DType, IdxT, N_READS>,
       num_blocks,
       block_dims,
+      0,
       a.data<bool>(),
       b.data<DType>(),
       c.data<DType>(),

View File

@@ -135,6 +135,7 @@ void unary_op_gpu_inplace(
       cu::unary_v<Op, InType, OutType, IdxT, N_READS>,
       num_blocks,
       block_dims,
+      0,
       in.data<InType>(),
       out.data<OutType>(),
       out.data_size());
@@ -146,6 +147,7 @@ void unary_op_gpu_inplace(
       cu::unary_g<Op, InType, OutType, IdxT>,
       num_blocks,
       block_dims,
+      0,
       in.data<InType>(),
       out.data<OutType>(),
       out.data_size(),