diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 321bd66b4..21cd677a8 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -171,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dim(), + 0, in.data(), out.data(), out.size(), diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 752f4b54c..0243d4f41 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -222,6 +222,7 @@ void binary_op_gpu_inplace( dims_constant()>, num_blocks, block_dims, + 0, a.data(), b.data(), out.data(), @@ -236,6 +237,7 @@ void binary_op_gpu_inplace( cu::binary_g, num_blocks, block_dims, + 0, a.data(), b.data(), out.data(), @@ -264,6 +266,7 @@ void binary_op_gpu_inplace( kernel, num_blocks, block_dims, + 0, a.data(), b.data(), out.data(), diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index a56f00468..49a747829 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -238,6 +238,7 @@ void binary_two_op_gpu_inplace( dims_constant()>, num_blocks, block_dims, + 0, a.data(), b.data(), out_a.data(), @@ -254,6 +255,7 @@ void binary_two_op_gpu_inplace( cu::binary_two_g, num_blocks, block_dims, + 0, a.data(), b.data(), out_a.data(), @@ -287,6 +289,7 @@ void binary_two_op_gpu_inplace( kernel, num_blocks, block_dims, + 0, a.data(), b.data(), out_a.data(), diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 6eda2533f..feb3169c0 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -295,7 +295,7 @@ void Compiled::eval_gpu( auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(outputs[0], large, work_per_thread); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 3ec7478be..9c2aa9838 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -76,6 +76,7 @@ void copy_contiguous( kernel, num_blocks, block_dims, + 0, in.data() + in_offset, out.data() + out_offset, out.data_size()); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index b65a24e54..64c67a176 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -77,6 +77,7 @@ void copy_general( cu::copy_gg_nd, num_blocks, block_dims, + 0, in_ptr, out_ptr, data_size, @@ -91,6 +92,7 @@ void copy_general( cu::copy_gg, num_blocks, block_dims, + 0, in_ptr, out_ptr, data_size, diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index bafc82057..7a7f0dca5 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -83,6 +83,7 @@ void copy_general_dynamic( dims_constant()>, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), @@ -98,6 +99,7 @@ void copy_general_dynamic( cu::copy_gg_dynamic, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index 052cf56c3..f381f14fa 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -68,6 +68,7 @@ void copy_general_input( cu::copy_g_nd, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), @@ -80,6 +81,7 @@ void copy_general_input( cu::copy_g, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 5871ce3e2..3b1fd7ddb 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -224,12 +224,14 @@ void CommandEncoder::add_kernel_node( void* func, dim3 grid_dim, dim3 block_dim, + uint32_t smem_bytes, void** params) { cudaKernelNodeParams kernel_params = {0}; kernel_params.func = func; kernel_params.gridDim = grid_dim; kernel_params.blockDim = block_dim; kernel_params.kernelParams = params; + kernel_params.sharedMemBytes = smem_bytes; add_kernel_node(kernel_params); } @@ -237,6 +239,7 @@ void CommandEncoder::add_kernel_node( CUfunction func, dim3 grid_dim, dim3 block_dim, + uint32_t smem_bytes, void** params) { CUDA_KERNEL_NODE_PARAMS kernel_params = {0}; kernel_params.func = func; @@ -247,6 +250,7 @@ void CommandEncoder::add_kernel_node( kernel_params.blockDimY = block_dim.y; kernel_params.blockDimZ = block_dim.z; kernel_params.kernelParams = params; + kernel_params.sharedMemBytes = smem_bytes; add_kernel_node(kernel_params); } diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 25c26fb0d..ea932082c 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -47,25 +47,34 @@ class CommandEncoder { void set_output_array(const array& arr); template - void - add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) { + void add_kernel_node( + F* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + Params&&... params) { constexpr size_t num = sizeof...(Params); void* ptrs[num]; size_t i = 0; ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( std::forward(params)), ...); - add_kernel_node((void*)func, grid_dim, block_dim, ptrs); + add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs); } void add_kernel_node( CUfunction func, dim3 grid_dim, dim3 block_dim, + uint32_t smem_bytes, void** params); - void - add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); + void add_kernel_node( + void* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + void** params); // Low-level graph helpers. void add_kernel_node(const cudaKernelNodeParams& params); diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu index 4e72fdc64..86733fb06 100644 --- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu +++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu @@ -108,6 +108,7 @@ void Matmul::run_batched( cu::set_mm_device_pointers, cuda::ceil_div(pointers.size(), block_size), block_size, + 0, pointers.data(), a.data(), b.data(), @@ -168,6 +169,7 @@ void Matmul::run_batched( cu::set_addmm_device_pointers, cuda::ceil_div(pointers.size(), block_size), block_size, + 0, pointers.data(), a.data(), b.data(), diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 69a85f6ac..22cff87d7 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(out, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(upd, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(idx, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(idx, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index d0d0f80c8..369b2547e 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -279,6 +279,7 @@ void LayerNorm::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), b.data(), @@ -391,6 +392,7 @@ void LayerNormVJP::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), g.data(), diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index 2afcc7e70..b90c300d0 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -150,6 +150,7 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { kernel, n_rows, block_dim(), + 0, in.data(), out.data(), axis_size); diff --git a/mlx/backend/cuda/quantized/affine_quantize.cu b/mlx/backend/cuda/quantized/affine_quantize.cu index 55322fa3e..94e67d135 100644 --- a/mlx/backend/cuda/quantized/affine_quantize.cu +++ b/mlx/backend/cuda/quantized/affine_quantize.cu @@ -261,6 +261,7 @@ void affine_quantize( kernel, num_blocks, block_dims, + 0, w.data(), wq.data(), scales.data(), @@ -316,6 +317,7 @@ void affine_dequantize( kernel, num_blocks, block_dims, + 0, wq.data(), scales.data(), biases.data(), diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 7221af356..26a3eb8b7 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { cu::rbitsc, grid, block, + 0, keys.data(), out.data(), grid_dims, @@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { cu::rbits, grid, block, + 0, keys.data(), out.data(), grid_dims, diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 166a11a79..b815597bd 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -120,6 +120,7 @@ void all_reduce( kernel, blocks, threads, + 0, static_cast(indata), intermediate.data(), block_step, @@ -146,6 +147,7 @@ void all_reduce( kernel, blocks, threads, + 0, static_cast(indata), out.data(), block_step, diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index fec5ca76b..04c400c47 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -230,7 +230,7 @@ void col_reduce_looped( auto kernel = cu::col_reduce_looped; encoder.add_kernel_node( - kernel, grid, blocks, indata, out.data(), args); + kernel, grid, blocks, 0, indata, out.data(), args); }); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 649d80190..8c0d380f5 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -41,7 +41,8 @@ void init_reduce( dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); grid.x = (grid.x + 1023) / 1024; - encoder.add_kernel_node(kernel, grid, block, out.data(), out.size()); + encoder.add_kernel_node( + kernel, grid, block, 0, out.data(), out.size()); }); }); } diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 61838ddd3..35f2287d6 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -269,7 +269,7 @@ void row_reduce_simple( int size = plan.shape.back(); encoder.add_kernel_node( - kernel, grid, block, indata, out.data(), out.size(), size); + kernel, grid, block, 0, indata, out.data(), out.size(), size); }); }); } @@ -322,7 +322,7 @@ void row_reduce_looped( }); encoder.add_kernel_node( - kernel, grid, block, indata, out.data(), out.size(), args); + kernel, grid, block, 0, indata, out.data(), out.size(), args); }); }); } diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 419f3d217..bc879c6f8 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -222,6 +222,7 @@ void RMSNorm::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), out.data(), @@ -316,6 +317,7 @@ void RMSNormVJP::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), g.data(), diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 517cddfe0..1c00f7a33 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -325,6 +325,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), @@ -341,6 +342,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), @@ -360,6 +362,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), @@ -381,6 +384,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index 969264e34..56d4ae275 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -414,6 +414,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { kernel, in.data_size() / axis_size, block_dim, + 0, in.data(), out.data(), axis_size); @@ -443,6 +444,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dim, + 0, in.data(), out.data(), axis_size, diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 2ff4464b0..d808bce38 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -151,6 +151,7 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { kernel, n_rows, block_dim(), + 0, in.data(), out.data(), axis_size); diff --git a/mlx/backend/cuda/steel/tiles.cuh b/mlx/backend/cuda/steel/tiles.cuh index a44113e6b..be6c46648 100644 --- a/mlx/backend/cuda/steel/tiles.cuh +++ b/mlx/backend/cuda/steel/tiles.cuh @@ -2,7 +2,7 @@ #pragma once -#include "mlx/backend/steel/utils.cuh" +#include "mlx/backend/cuda/steel/utils.cuh" namespace mlx::core::cu { @@ -223,6 +223,57 @@ struct RegisterTile { } }; +/** + * A simple container of multiple Tile16x16. + * + * Provides utility functions for loading and manipulating collections of basic + * tiles. + */ +template +struct RegisterTile { + static constexpr int ROWS = ROWS_; + static constexpr int COLS = COLS_; + static constexpr int TILES_X = COLS / 16; + static constexpr int TILES_Y = ROWS / 16; + + Tile16x16 data[TILES_X * TILES_Y]; + + __device__ inline void fill(T v) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].fill(v); + } + } + } + + template + __device__ inline void + load(Tile& tile, uint32_t base_address, int row, int col) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].load( + tile.loc(base_address, row + i * 16, col + j * 16)); + } + } + } + + template + __device__ inline void store_global(U* x, int N, int row, int col) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].store_global( + x + (row + i * 16) * N + col + j * 16, N); + } + } + } +}; + template struct SharedTile { static constexpr int ROWS = ROWS_; diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 58d3fa119..93a0839d5 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -130,6 +130,7 @@ void ternary_op_gpu_inplace( cu::ternary_g_nd, num_blocks, block_dims, + 0, a.data(), b.data(), c.data(), @@ -146,6 +147,7 @@ void ternary_op_gpu_inplace( cu::ternary_g, num_blocks, block_dims, + 0, a.data(), b.data(), c.data(), @@ -168,6 +170,7 @@ void ternary_op_gpu_inplace( cu::ternary_v, num_blocks, block_dims, + 0, a.data(), b.data(), c.data(), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 9c8db9a89..96888da97 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -135,6 +135,7 @@ void unary_op_gpu_inplace( cu::unary_v, num_blocks, block_dims, + 0, in.data(), out.data(), out.data_size()); @@ -146,6 +147,7 @@ void unary_op_gpu_inplace( cu::unary_g, num_blocks, block_dims, + 0, in.data(), out.data(), out.data_size(),