diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 321bd66b47..21cd677a89 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -171,6 +171,7 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dim(), + 0, in.data(), out.data(), out.size(), diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 752f4b54c5..0243d4f414 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -222,6 +222,7 @@ void binary_op_gpu_inplace( dims_constant()>, num_blocks, block_dims, + 0, a.data(), b.data(), out.data(), @@ -236,6 +237,7 @@ void binary_op_gpu_inplace( cu::binary_g, num_blocks, block_dims, + 0, a.data(), b.data(), out.data(), @@ -264,6 +266,7 @@ void binary_op_gpu_inplace( kernel, num_blocks, block_dims, + 0, a.data(), b.data(), out.data(), diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index a56f004687..49a7478296 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -238,6 +238,7 @@ void binary_two_op_gpu_inplace( dims_constant()>, num_blocks, block_dims, + 0, a.data(), b.data(), out_a.data(), @@ -254,6 +255,7 @@ void binary_two_op_gpu_inplace( cu::binary_two_g, num_blocks, block_dims, + 0, a.data(), b.data(), out_a.data(), @@ -287,6 +289,7 @@ void binary_two_op_gpu_inplace( kernel, num_blocks, block_dims, + 0, a.data(), b.data(), out_a.data(), diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 7f859b91a0..9e63a269b6 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -334,7 +334,7 @@ void Compiled::eval_gpu( auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(outputs[0], large, work_per_thread); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 3ec7478be6..9c2aa9838b 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -76,6 +76,7 @@ void copy_contiguous( kernel, num_blocks, block_dims, + 0, in.data() + in_offset, out.data() + out_offset, out.data_size()); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index b65a24e547..64c67a176a 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -77,6 +77,7 @@ void copy_general( cu::copy_gg_nd, num_blocks, block_dims, + 0, in_ptr, out_ptr, data_size, @@ -91,6 +92,7 @@ void copy_general( cu::copy_gg, num_blocks, block_dims, + 0, in_ptr, out_ptr, data_size, diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index bafc82057f..7a7f0dca58 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -83,6 +83,7 @@ void copy_general_dynamic( dims_constant()>, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), @@ -98,6 +99,7 @@ void copy_general_dynamic( cu::copy_gg_dynamic, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index 052cf56c33..f381f14fa0 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -68,6 +68,7 @@ void 
copy_general_input( cu::copy_g_nd, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), @@ -80,6 +81,7 @@ void copy_general_input( cu::copy_g, num_blocks, block_dims, + 0, in_ptr, out_ptr, out.size(), diff --git a/mlx/backend/cuda/device.cpp b/mlx/backend/cuda/device.cpp index 59357febf0..7a454e7d7f 100644 --- a/mlx/backend/cuda/device.cpp +++ b/mlx/backend/cuda/device.cpp @@ -224,12 +224,14 @@ void CommandEncoder::add_kernel_node( void* func, dim3 grid_dim, dim3 block_dim, + uint32_t smem_bytes, void** params) { cudaKernelNodeParams kernel_params = {0}; kernel_params.func = func; kernel_params.gridDim = grid_dim; kernel_params.blockDim = block_dim; kernel_params.kernelParams = params; + kernel_params.sharedMemBytes = smem_bytes; add_kernel_node(kernel_params); } @@ -237,6 +239,7 @@ void CommandEncoder::add_kernel_node( CUfunction func, dim3 grid_dim, dim3 block_dim, + uint32_t smem_bytes, void** params) { CUDA_KERNEL_NODE_PARAMS kernel_params = {0}; kernel_params.func = func; @@ -247,6 +250,7 @@ void CommandEncoder::add_kernel_node( kernel_params.blockDimY = block_dim.y; kernel_params.blockDimZ = block_dim.z; kernel_params.kernelParams = params; + kernel_params.sharedMemBytes = smem_bytes; add_kernel_node(kernel_params); } diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 25c26fb0d2..ea932082c5 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -47,25 +47,34 @@ class CommandEncoder { void set_output_array(const array& arr); template - void - add_kernel_node(F* func, dim3 grid_dim, dim3 block_dim, Params&&... params) { + void add_kernel_node( + F* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + Params&&... params) { constexpr size_t num = sizeof...(Params); void* ptrs[num]; size_t i = 0; ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( std::forward(params)), ...); - add_kernel_node((void*)func, grid_dim, block_dim, ptrs); + add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs); } void add_kernel_node( CUfunction func, dim3 grid_dim, dim3 block_dim, + uint32_t smem_bytes, void** params); - void - add_kernel_node(void* func, dim3 grid_dim, dim3 block_dim, void** params); + void add_kernel_node( + void* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + void** params); // Low-level graph helpers. 
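The new `smem_bytes` argument is threaded through both `add_kernel_node` overloads into `cudaKernelNodeParams::sharedMemBytes` (and `CUDA_KERNEL_NODE_PARAMS::sharedMemBytes` on the `CUfunction` path), which is why every existing call site in this patch now passes an explicit 0. A kernel that declares dynamically sized `extern __shared__` storage would forward its byte count instead. A minimal sketch of such a call site; the kernel, sizes, and pointer names below are hypothetical and not taken from this diff:

    // Kernel side (hypothetical): shared memory declared here, sized at launch.
    // __global__ void tile_sum(const float* in, float* out, int tile) {
    //   extern __shared__ float scratch[];  // receives smem_bytes bytes
    //   ...
    // }

    // Host side: the byte count rides along with the launch geometry and lands
    // in the graph node's sharedMemBytes field.
    int tile = 256;
    uint32_t smem_bytes = tile * sizeof(float);
    encoder.add_kernel_node(
        cu::tile_sum, num_blocks, block_dims, smem_bytes, in_ptr, out_ptr, tile);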
void add_kernel_node(const cudaKernelNodeParams& params); diff --git a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu index 4e72fdc64b..86733fb06f 100644 --- a/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu +++ b/mlx/backend/cuda/gemms/cublas_batched_gemm_12_9.cu @@ -108,6 +108,7 @@ void Matmul::run_batched( cu::set_mm_device_pointers, cuda::ceil_div(pointers.size(), block_size), block_size, + 0, pointers.data(), a.data(), b.data(), @@ -168,6 +169,7 @@ void Matmul::run_batched( cu::set_addmm_device_pointers, cuda::ceil_div(pointers.size(), block_size), block_size, + 0, pointers.data(), a.data(), b.data(), diff --git a/mlx/backend/cuda/gemms/gemv.cu b/mlx/backend/cuda/gemms/gemv.cu index 55333adea3..552ab9cda6 100644 --- a/mlx/backend/cuda/gemms/gemv.cu +++ b/mlx/backend/cuda/gemms/gemv.cu @@ -143,6 +143,7 @@ void gemv( kernel, num_blocks_x, block_dims, + 0, mat, vec, out.data(), @@ -154,6 +155,7 @@ void gemv( kernel, dim3{num_blocks_x, batch_count}, block_dims, + 0, mat, vec, out.data(), diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 69a85f6acb..22cff87d7f 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -129,7 +129,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(out, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } void Scatter::eval_gpu(const std::vector& inputs, array& out) { @@ -230,7 +230,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(upd, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -318,7 +318,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(idx, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { @@ -422,7 +422,7 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(idx, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, args.args()); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } } // namespace mlx::core diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index d0d0f80c82..369b2547e2 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -279,6 +279,7 @@ void LayerNorm::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), b.data(), @@ -391,6 +392,7 @@ void LayerNormVJP::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), g.data(), diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index 2afcc7e705..b90c300d03 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -150,6 +150,7 @@ void 
LogSumExp::eval_gpu(const std::vector& inputs, array& out) { kernel, n_rows, block_dim(), + 0, in.data(), out.data(), axis_size); diff --git a/mlx/backend/cuda/quantized/affine_quantize.cu b/mlx/backend/cuda/quantized/affine_quantize.cu index 55322fa3e8..94e67d135c 100644 --- a/mlx/backend/cuda/quantized/affine_quantize.cu +++ b/mlx/backend/cuda/quantized/affine_quantize.cu @@ -261,6 +261,7 @@ void affine_quantize( kernel, num_blocks, block_dims, + 0, w.data(), wq.data(), scales.data(), @@ -316,6 +317,7 @@ void affine_dequantize( kernel, num_blocks, block_dims, + 0, wq.data(), scales.data(), biases.data(), diff --git a/mlx/backend/cuda/random.cu b/mlx/backend/cuda/random.cu index 7221af3562..26a3eb8b71 100644 --- a/mlx/backend/cuda/random.cu +++ b/mlx/backend/cuda/random.cu @@ -170,6 +170,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { cu::rbitsc, grid, block, + 0, keys.data(), out.data(), grid_dims, @@ -180,6 +181,7 @@ void RandomBits::eval_gpu(const std::vector& inputs, array& out) { cu::rbits, grid, block, + 0, keys.data(), out.data(), grid_dims, diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 166a11a796..b815597bd4 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -120,6 +120,7 @@ void all_reduce( kernel, blocks, threads, + 0, static_cast(indata), intermediate.data(), block_step, @@ -146,6 +147,7 @@ void all_reduce( kernel, blocks, threads, + 0, static_cast(indata), out.data(), block_step, diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index fec5ca76b6..04c400c473 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -230,7 +230,7 @@ void col_reduce_looped( auto kernel = cu::col_reduce_looped; encoder.add_kernel_node( - kernel, grid, blocks, indata, out.data(), args); + kernel, grid, blocks, 0, indata, out.data(), args); }); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 649d801903..8c0d380f52 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -41,7 +41,8 @@ void init_reduce( dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 block(grid.x < 1024 ? 
grid.x : 1024, 1, 1); grid.x = (grid.x + 1023) / 1024; - encoder.add_kernel_node(kernel, grid, block, out.data(), out.size()); + encoder.add_kernel_node( + kernel, grid, block, 0, out.data(), out.size()); }); }); } diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 61838ddd3b..35f2287d6b 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -269,7 +269,7 @@ void row_reduce_simple( int size = plan.shape.back(); encoder.add_kernel_node( - kernel, grid, block, indata, out.data(), out.size(), size); + kernel, grid, block, 0, indata, out.data(), out.size(), size); }); }); } @@ -322,7 +322,7 @@ void row_reduce_looped( }); encoder.add_kernel_node( - kernel, grid, block, indata, out.data(), out.size(), args); + kernel, grid, block, 0, indata, out.data(), out.size(), args); }); }); } diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 419f3d2179..bc879c6f8e 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -222,6 +222,7 @@ void RMSNorm::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), out.data(), @@ -316,6 +317,7 @@ void RMSNormVJP::eval_gpu( kernel, n_rows, block_dim(), + 0, x.data(), w.data(), g.data(), diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 517cddfe05..1c00f7a334 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -325,6 +325,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), @@ -341,6 +342,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), @@ -360,6 +362,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), @@ -381,6 +384,7 @@ void RoPE::eval_gpu( kernel, grid, block, + 0, (donated ? out : in).data(), out.data(), offset.data(), diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index 969264e34c..56d4ae2754 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -414,6 +414,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { kernel, in.data_size() / axis_size, block_dim, + 0, in.data(), out.data(), axis_size); @@ -443,6 +444,7 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dim, + 0, in.data(), out.data(), axis_size, diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 2ff4464b01..d808bce382 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -151,6 +151,7 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { kernel, n_rows, block_dim(), + 0, in.data(), out.data(), axis_size); diff --git a/mlx/backend/cuda/steel/defines.cuh b/mlx/backend/cuda/steel/defines.cuh new file mode 100644 index 0000000000..bf920428fb --- /dev/null +++ b/mlx/backend/cuda/steel/defines.cuh @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#define MLX_UNROLL _Pragma("unroll") + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) +#define MLX_CUDA_SM_80_ENABLED +#endif diff --git a/mlx/backend/cuda/steel/gemm.cuh b/mlx/backend/cuda/steel/gemm.cuh new file mode 100644 index 0000000000..99580d2def --- /dev/null +++ b/mlx/backend/cuda/steel/gemm.cuh @@ -0,0 +1,101 @@ + +#include "mlx/backend/cuda/steel/mma.cuh" +#include "mlx/backend/cuda/steel/tiles.cuh" + +namespace mlx::core::cu { + +/** + * An example gemm written with the utils. 
+ * + * Computes A @ B.T when A and B are all aligned with the block sizes. + */ +template +__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) { + constexpr int WARPS_M = 2; + constexpr int WARPS_N = 2; + constexpr int NUM_WARPS = WARPS_M * WARPS_N; + constexpr int WARP_STEP_M = BM / WARPS_M; + constexpr int WARP_STEP_N = BN / WARPS_N; + + // Precompute some offsets for each thread + const int warpid = threadIdx.x / 32; + const int laneid = threadIdx.x % 32; + const int wm = warpid / WARPS_N; + const int wn = warpid % WARPS_N; + const int offset_m = wm * WARP_STEP_M; + const int offset_n = wn * WARP_STEP_N; + + // Allocate shared memory + extern __shared__ char shmem[]; + SharedTile(&as)[2] = *(SharedTile(*)[2])(&shmem[0]); + SharedTile(&bs)[2] = + *(SharedTile(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]); + + // Allocate registers for the MMA + RegisterTile C; + RegisterTile A; + RegisterTile B; + + // Move the global pointers to the tile + a += blockIdx.y * BM * K; + b += blockIdx.x * BN * K; + y += blockIdx.y * BM * N + blockIdx.x * BN; + + // Zero the accumulators + C.fill(0); + + // Start the SM pipeline + load_async(as[0], as[0].base_addr(), a, K); + load_async(bs[0], bs[0].base_addr(), b, K); + cp_async_commit(); + + int tic = 0; + for (int k_block = BK; k_block < K; k_block += BK) { + load_async(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K); + load_async(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K); + cp_async_commit(); + cp_async_wait<1>(); + __syncthreads(); + + MLX_UNROLL + for (int k = 0; k < BK / 16; k++) { + A.load( + as[tic], + as[tic].base_addr(), + offset_m + laneid % 16, + k * 16 + laneid / 16 * 8); + B.load( + bs[tic], + bs[tic].base_addr(), + offset_n + laneid % 16, + k * 16 + laneid / 16 * 8); + + mma_t(C, A, B); + } + + tic ^= 1; + } + + // Empty the pipeline + cp_async_wait_all(); + __syncthreads(); + MLX_UNROLL + for (int k = 0; k < BK / 16; k++) { + A.load( + as[tic], + as[tic].base_addr(), + offset_m + laneid % 16, + k * 16 + laneid / 16 * 8); + B.load( + bs[tic], + bs[tic].base_addr(), + offset_n + laneid % 16, + k * 16 + laneid / 16 * 8); + + mma_t(C, A, B); + } + + C.store_global(y, N, offset_m, offset_n); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/steel/mma.cuh b/mlx/backend/cuda/steel/mma.cuh new file mode 100644 index 0000000000..94e314909b --- /dev/null +++ b/mlx/backend/cuda/steel/mma.cuh @@ -0,0 +1,117 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/steel/defines.cuh" +#include "mlx/backend/cuda/steel/tiles.cuh" + +namespace mlx::core::cu { + +/** + * Fallback mma. + * + * We should probably a) implement a fallback or complain about it to the + * compiler. + */ +template +__device__ inline void +mma_t(Tile16x16& C, Tile16x16& A, Tile16x16& B) {} + +/** + * Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16 + * float tile. 
+ * + * We actually perform C += A @ B.T + */ +__device__ __forceinline__ void mma_t( + Tile16x16& C, + Tile16x16<__nv_bfloat16>& A, + Tile16x16<__nv_bfloat16>& B) { +#if defined(MLX_CUDA_SM_80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%10, %11, %12, %13};" + + // D matrix + : "+f"(C.values[0].x), + "+f"(C.values[0].y), + "+f"(C.values[1].x), + "+f"(C.values[1].y) + + // A matrix + : "r"(*(uint32_t*)(&A.values[0])), + "r"(*(uint32_t*)(&A.values[1])), + "r"(*(uint32_t*)(&A.values[2])), + "r"(*(uint32_t*)(&A.values[3])), + + // B matrix + "r"(*(uint32_t*)(&B.values[0])), + "r"(*(uint32_t*)(&B.values[2])), + + // C matrix + "f"(C.values[0].x), + "f"(C.values[0].y), + "f"(C.values[1].x), + "f"(C.values[1].y)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%10, %11, %12, %13};" + + // D matrix + : "+f"(C.values[2].x), + "+f"(C.values[2].y), + "+f"(C.values[3].x), + "+f"(C.values[3].y) + + // A matrix + : "r"(*(uint32_t*)(&A.values[0])), + "r"(*(uint32_t*)(&A.values[1])), + "r"(*(uint32_t*)(&A.values[2])), + "r"(*(uint32_t*)(&A.values[3])), + + // B matrix + "r"(*(uint32_t*)(&B.values[1])), + "r"(*(uint32_t*)(&B.values[3])), + + // C matrix + "f"(C.values[2].x), + "f"(C.values[2].y), + "f"(C.values[3].x), + "f"(C.values[3].y)); +#endif +} + +/** + * Multiply larger register tiles by delegating to mma_t. + */ +template +__device__ __forceinline__ void mma_t( + RegisterTile& C, + RegisterTile& A, + RegisterTile& B) { + constexpr int TILES_M = RegisterTile::TILES_Y; + constexpr int TILES_K = RegisterTile::TILES_X; + constexpr int TILES_N = RegisterTile::TILES_Y; + + MLX_UNROLL + for (int k = 0; k < TILES_K; k++) { + MLX_UNROLL + for (int m = 0; m < TILES_M; m++) { + MLX_UNROLL + for (int n = 0; n < TILES_N; n++) { + mma_t( + C.data[m * TILES_N + n], + A.data[m * TILES_K + k], + B.data[n * TILES_K + k]); + } + } + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/steel/tiles.cuh b/mlx/backend/cuda/steel/tiles.cuh new file mode 100644 index 0000000000..be6c46648b --- /dev/null +++ b/mlx/backend/cuda/steel/tiles.cuh @@ -0,0 +1,471 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/steel/utils.cuh" + +namespace mlx::core::cu { + +// Map types to their vector of 2 type float -> float2, double -> double2 etc +template +struct Vector2; +template <> +struct Vector2 { + using type = double2; +}; +template <> +struct Vector2 { + using type = float2; +}; +template <> +struct Vector2<__half> { + using type = __half2; +}; +template <> +struct Vector2<__nv_bfloat16> { + using type = __nv_bfloat162; +}; +template +using Vector2_t = typename Vector2::type; + +/** + * The basic building block for Ampere mmas. A 16x16 tile distributed across + * the warp. + * + * Each thread holds 8 values. They are distributed according to + * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float + * + * For use instructions see the individual methods eg load(). + */ +template +struct Tile16x16 { + using T2 = Vector2_t; + + T2 values[4]; + + __device__ inline void fill(T v) { + T2 v2 = {v, v}; + for (int i = 0; i < 4; i++) { + values[i] = v2; + } + } + + /** + * Load a 16x16 tile from shared memory. + * + * The instruction is a bit weird in the sense that the address provided by + * each thread and the elements loaded are not the same. 
+ * + * We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a + * result the warp provides 4*8 = 32 addresses one per row. + * + * Threads 0-7 provide the addresses for the first tile, 8-15 for the second + * and so on. For instance to load a non swizzled tile we would do + * + * base_addr + (laneid % 16) * BK + (laneid / 2) * 8 + * + * See + * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix + */ + __device__ __forceinline__ void load(uint32_t row_address) { + if constexpr ( + std::is_same_v || std::is_same_v) { + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(*(uint32_t*)&(values[0])), + "=r"(*(uint32_t*)&(values[1])), + "=r"(*(uint32_t*)&(values[2])), + "=r"(*(uint32_t*)&(values[3])) + : "r"(row_address)); + } + } + + /** + * Store the tile to the address pointed to by `x`. + * + * The provided pointer is a generic pointer but this is meant to be used to + * store to global memory. For storing to shared memory we should use + * `stmatrix`. + * + * This also showcases the format of the tile quite nicely. Each register is + * holding to adjacent values. The indices are + * + * row + 0, col + 0 + * row + 8, col + 0 + * row + 0, col + 8 + * row + 8, col + 8 + * + * Given that we are dealing with Vector2_t the column offsets are 4 + * instead of 8. + */ + template + __device__ inline void store_global(U* x, int N) { + using U2 = Vector2_t; + U2* x2 = reinterpret_cast(x); + const int laneid = threadIdx.x % 32; + const int row = laneid / 4; + const int col = laneid % 4; + if constexpr (std::is_same_v) { + x2[(row + 0) * (N / 2) + col + 0] = values[0]; + x2[(row + 0) * (N / 2) + col + 4] = values[2]; + x2[(row + 8) * (N / 2) + col + 0] = values[1]; + x2[(row + 8) * (N / 2) + col + 4] = values[3]; + } else if constexpr ( + std::is_same_v && std::is_same_v) { + x2[(row + 0) * (N / 2) + col + 0] = + __floats2bfloat162_rn(values[0].x, values[0].y); + x2[(row + 0) * (N / 2) + col + 4] = + __floats2bfloat162_rn(values[2].x, values[2].y); + x2[(row + 8) * (N / 2) + col + 0] = + __floats2bfloat162_rn(values[1].x, values[1].y); + x2[(row + 8) * (N / 2) + col + 4] = + __floats2bfloat162_rn(values[3].x, values[3].y); + } + } + + template + __device__ inline void store_global_safe(U* x, int N, int max_rows) { + const int laneid = threadIdx.x % 32; + const int row = laneid / 4; + const int col = laneid % 4; + if (row < max_rows) { + x[(row + 0) * N + 2 * col + 0] = static_cast(values[0].x); + x[(row + 0) * N + 2 * col + 1] = static_cast(values[0].y); + x[(row + 0) * N + 2 * col + 8] = static_cast(values[2].x); + x[(row + 0) * N + 2 * col + 9] = static_cast(values[2].y); + } + if (row + 8 < max_rows) { + x[(row + 8) * N + 2 * col + 0] = static_cast(values[1].x); + x[(row + 8) * N + 2 * col + 1] = static_cast(values[1].y); + x[(row + 8) * N + 2 * col + 8] = static_cast(values[3].x); + x[(row + 8) * N + 2 * col + 9] = static_cast(values[3].y); + } + } +}; + +/** + * A simple container of multiple Tile16x16. + * + * Provides utility functions for loading and manipulating collections of basic + * tiles. 
+ */ +template +struct RegisterTile { + static constexpr int ROWS = ROWS_; + static constexpr int COLS = COLS_; + static constexpr int TILES_X = COLS / 16; + static constexpr int TILES_Y = ROWS / 16; + + Tile16x16 data[TILES_X * TILES_Y]; + + __device__ inline void fill(T v) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].fill(v); + } + } + } + + template + __device__ __forceinline__ void + load(Tile& tile, uint32_t base_address, int row, int col) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].load( + tile.loc(base_address, row + i * 16, col + j * 16)); + } + } + } + + template + __device__ __forceinline__ void + load(Tile& tile, F f, uint32_t base_address, int row, int col) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + f(data[i * TILES_X + j], + tile, + base_address, + row + i * 16, + col + j * 16); + } + } + } + + template + __device__ inline void store_global(U* x, int N, int row, int col) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].store_global( + x + (row + i * 16) * N + col + j * 16, N); + } + } + } + + template + __device__ inline void + store_global_safe(U* x, int N, int row, int col, int max_rows) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].store_global_safe( + x + (row + i * 16) * N + col + j * 16, N, max_rows - row - i * 16); + } + } + } +}; + +/** + * A simple container of multiple Tile16x16. + * + * Provides utility functions for loading and manipulating collections of basic + * tiles. + */ +template +struct RegisterTile { + static constexpr int ROWS = ROWS_; + static constexpr int COLS = COLS_; + static constexpr int TILES_X = COLS / 16; + static constexpr int TILES_Y = ROWS / 16; + + Tile16x16 data[TILES_X * TILES_Y]; + + __device__ inline void fill(T v) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].fill(v); + } + } + } + + template + __device__ inline void + load(Tile& tile, uint32_t base_address, int row, int col) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].load( + tile.loc(base_address, row + i * 16, col + j * 16)); + } + } + } + + template + __device__ inline void store_global(U* x, int N, int row, int col) { + MLX_UNROLL + for (int i = 0; i < TILES_Y; i++) { + MLX_UNROLL + for (int j = 0; j < TILES_X; j++) { + data[i * TILES_X + j].store_global( + x + (row + i * 16) * N + col + j * 16, N); + } + } + } +}; + +template +struct SharedTile { + static constexpr int ROWS = ROWS_; + static constexpr int COLS = COLS_; + static constexpr int TILES_X = COLS / 16; + static constexpr int TILES_Y = ROWS / 16; + static constexpr int NUMEL = ROWS * COLS; + + // Swizzle taken from ThunderKittens. Should be changed when we switch to + // cute Layouts. + // + // See inludes/types/shared/st.cuh + // + // I do feel that it is too math heavy and can be improved. Also the math is + // done every time although the addresses don't change from load to load. I + // guess we are expecting the compiler to figure that out. + static constexpr int swizzle_bytes = + (sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 
64 : 32)) + : (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0)); + + T data[ROWS * COLS]; + + __device__ inline uint32_t base_addr() const { + return __cvta_generic_to_shared(&data[0]); + } + + // Return a pointer to the element at (row, col) using the swizzle. + __device__ static inline T* ptr(T* ptr, int row, int col) { + if constexpr (swizzle_bytes > 0) { + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = col / subtile_cols; + const uint64_t addr = + (uint64_t)(&ptr + [outer_idx * ROWS * subtile_cols + row * subtile_cols + + col % subtile_cols]); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (T*)(addr ^ swizzle); + } else { + return ptr + row * COLS + col; + } + } + + // Return the location of the element at (row, col) using the swizzle. + __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) { + if constexpr (swizzle_bytes > 0) { + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = col / subtile_cols; + const uint32_t addr = ptr + + sizeof(T) * + (outer_idx * ROWS * subtile_cols + row * subtile_cols + + col % subtile_cols); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (addr ^ swizzle); + } else { + return ptr + sizeof(T) * (row * COLS + col); + } + } + + // Convenience functions to edit elements going through the swizzle. + __device__ inline T& operator()(int row, int col) { + return *ptr(data, row, col); + } + __device__ inline void store(float4& v, int row, int col) { + *(reinterpret_cast(ptr(data, row, col))) = v; + } + __device__ inline void store(float2& v, int row, int col) { + *(reinterpret_cast(ptr(data, row, col))) = v; + } + __device__ inline void store(float& v, int row, int col) { + *(reinterpret_cast(ptr(data, row, col))) = v; + } + template + __device__ inline void store(T (&v)[N], int row, int col) { + if constexpr (sizeof(T) * N == 4) { + store(*(reinterpret_cast(&v[0])), row, col); + } else if constexpr (sizeof(T) * N == 8) { + store(*(reinterpret_cast(&v[0])), row, col); + } else if constexpr (sizeof(T) * N == 16) { + store(*(reinterpret_cast(&v[0])), row, col); + } else { + MLX_UNROLL + for (int i = 0; i < N; i++) { + *ptr(data, row, col + i) = v[i]; + } + } + } +}; + +/** + * Load the tile from global memory by loading 16 bytes at a time and storing + * them immediately. + * + * Can also be used as a fallback for architectures before sm_80. + */ +template +__device__ inline void load(Tile& tile, const T* x, int N) { + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T); + constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD; + constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS; + constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD; + constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; + + const int row = threadIdx.x / NUM_LOADS_PER_ROW; + const int col = threadIdx.x % NUM_LOADS_PER_ROW; + + x += row * N + col * ELEMENTS_PER_LOAD; + + MLX_UNROLL + for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { + float4 tmp; + tmp = *(reinterpret_cast(&x[i * STEP_ROWS * N])); + tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD); + } +} + +/** + * The asynchronous equivalent of load. + * + * Loads the tile from global memory by submitting a bunch of async copy + * instructions. 
The copy won't start until commit is called and we don't have + * a guarantee it will finish until wait is called. + * + * It should be used as follows + * + * load(...) + * load(...) + * cp_async_commit() + * do_other_stuff() + * cp_async_wait_all() + * do_stuff_with_shmem() + */ +template +__device__ inline void +load_async(Tile& tile, uint32_t base_address, const T* x, int N) { + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T); + constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD; + constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS; + constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD; + constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; + + const int row = threadIdx.x / NUM_LOADS_PER_ROW; + const int col = threadIdx.x % NUM_LOADS_PER_ROW; + + x += row * N + col * ELEMENTS_PER_LOAD; + + MLX_UNROLL + for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { + cp_async<16>( + tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD), + x + i * STEP_ROWS * N); + } +} + +/** + * Same as load_async but checks if we can load the row. + * + * NOTE: It should be changed to use a predicated cp async instead. + */ +template +__device__ inline void load_async_safe( + Tile& tile, + uint32_t base_address, + const T* x, + int N, + int max_rows) { + constexpr int NUM_THREADS = NUM_WARPS * 32; + constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T); + constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD; + constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS; + constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD; + constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW; + + const int row = threadIdx.x / NUM_LOADS_PER_ROW; + const int col = threadIdx.x % NUM_LOADS_PER_ROW; + + x += row * N + col * ELEMENTS_PER_LOAD; + + MLX_UNROLL + for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) { + if (row + i * STEP_ROWS < max_rows) { + cp_async<16>( + tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD), + x + i * STEP_ROWS * N); + } else { + float4 tmp = {0, 0, 0, 0}; + tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD); + } + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/steel/utils.cuh b/mlx/backend/cuda/steel/utils.cuh new file mode 100644 index 0000000000..0957c09d0c --- /dev/null +++ b/mlx/backend/cuda/steel/utils.cuh @@ -0,0 +1,89 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device/utils.cuh" +#include "mlx/backend/cuda/steel/defines.cuh" + +namespace mlx::core::cu { + +/** + * Copy bytes from the global memory address pointed to by x to the smem + * address pointed to by row_address. + * + * A simple wrapper over the PTX. + */ +template +__device__ inline void cp_async(uint32_t row_address, const T* x) { + static_assert( + N == 16 || N == 8 || N == 4, + "cp.async is only supported for N in {4, 8, 16}."); +#if defined(MLX_CUDA_SM_80_ENABLED) + if constexpr (N == 16) { + asm volatile( + "cp.async.ca.shared::cta.global [%0], [%1], 16;\n" ::"r"(row_address), + "l"(reinterpret_cast(x))); + } else if constexpr (N == 8) { + asm volatile( + "cp.async.ca.shared::cta.global [%0], [%1], 8;\n" ::"r"(row_address), + "l"(reinterpret_cast(x))); + } else if constexpr (N == 4) { + asm volatile( + "cp.async.ca.shared::cta.global [%0], [%1], 4;\n" ::"r"(row_address), + "l"(reinterpret_cast(x))); + } +#endif +} + +/** + * Submit all the previous async copies to be executed. 
+ */ +__device__ inline void cp_async_commit() { +#if defined(MLX_CUDA_SM_80_ENABLED) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +/** + * Wait for all but N of the async copies to finish. + */ +template +__device__ inline void cp_async_wait() { +#if defined(MLX_CUDA_SM_80_ENABLED) + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + } +#endif +} + +/** + * Wait for all the async copies to finish. + */ +__device__ inline void cp_async_wait_all() { + cp_async_wait<0>(); +} + +/** + * Extract ``bits`` bits from the 32 bit value. + * + * Single instruction shift and mask. + */ +template +__device__ inline uint32_t extract_bits(uint32_t value, int start_bit) { + static_assert( + bits == 2 || bits == 4 || bits == 8, + "extract_bits only supports 2, 4, 8 for now."); + uint32_t result; + if constexpr (bits == 2) { + asm("bfe.u32 %0, %1, %2, 2;" : "=r"(result) : "r"(value), "r"(start_bit)); + } else if constexpr (bits == 4) { + asm("bfe.u32 %0, %1, %2, 4;" : "=r"(result) : "r"(value), "r"(start_bit)); + } else if constexpr (bits == 8) { + asm("bfe.u32 %0, %1, %2, 8;" : "=r"(result) : "r"(value), "r"(start_bit)); + } + return result; +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index bc4097d99c..cfc0e10b85 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -130,6 +130,7 @@ void ternary_op_gpu_inplace( cu::ternary_g_nd, num_blocks, block_dims, + 0, a.data(), b.data(), c.data(), @@ -146,6 +147,7 @@ void ternary_op_gpu_inplace( cu::ternary_g, num_blocks, block_dims, + 0, a.data(), b.data(), c.data(), @@ -168,6 +170,7 @@ void ternary_op_gpu_inplace( cu::ternary_v, num_blocks, block_dims, + 0, a.data(), b.data(), c.data(), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index 9c8db9a89e..96888da974 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -135,6 +135,7 @@ void unary_op_gpu_inplace( cu::unary_v, num_blocks, block_dims, + 0, in.data(), out.data(), out.data_size()); @@ -146,6 +147,7 @@ void unary_op_gpu_inplace( cu::unary_g, num_blocks, block_dims, + 0, in.data(), out.data(), out.data_size(),
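Tying the two halves of this patch together: `ab_t_aligned` carves its double-buffered tiles out of a single `extern __shared__` allocation (`as` occupies the first `2 * BM * BK * sizeof(T)` bytes and `bs` the next `2 * BN * BK * sizeof(T)`), so a launch has to pass that total through the new `smem_bytes` parameter. A rough host-side sketch follows; the tile sizes, problem sizes, and pointer names are illustrative, the template parameter order `<typename T, int BM, int BN, int BK>` is assumed rather than shown in the hunk above, and the kernel itself requires M, N and K to be multiples of the tile sizes:

    // Illustrative launch of cu::ab_t_aligned (not part of this diff).
    constexpr int BM = 128, BN = 128, BK = 32;  // example tile sizes
    using T = __nv_bfloat16;

    int M = 4096, N = 4096, K = 4096;  // assumed multiples of BM, BN, BK
    dim3 grid(N / BN, M / BM);         // blockIdx.x walks N tiles, blockIdx.y walks M tiles
    dim3 block(4 * 32);                // WARPS_M * WARPS_N = 4 warps per block
    uint32_t smem_bytes =
        2 * (BM + BN) * BK * sizeof(T);  // as[2] plus bs[2], double buffered

    encoder.add_kernel_node(
        cu::ab_t_aligned<T, BM, BN, BK>,
        grid,
        block,
        smem_bytes,
        a_ptr,  // const T*, row-major M x K
        b_ptr,  // const T*, row-major N x K (the kernel computes A @ B.T)
        y_ptr,  // T*, row-major M x N
        N,
        K);

At these sizes the allocation is 32 KB of dynamic shared memory per block, comfortably under the default 48 KB limit, so no `cudaFuncSetAttribute` opt-in is needed.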