diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index 2e91ae61f..727cb027d 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -155,25 +155,33 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { using T = cuda_type_t; constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); - dim3 block_dims{BLOCK_DIM, 1, 1}; - auto kernel = - cu::arg_reduce_general, BLOCK_DIM, N_READS>; - if (reduce_type_ == ArgReduce::ArgMin) { - kernel = cu::arg_reduce_general, BLOCK_DIM, N_READS>; - } - kernel<<>>( - in.data(), - out.data(), - out.size(), - const_param(shape), - const_param(in_strides), - const_param(out_strides), - ndim, - axis_stride, - axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim_constant) { + dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); + dim3 block_dims{block_dim_constant(), 1, 1}; + auto kernel = cu::arg_reduce_general< + T, + cu::ArgMax, + block_dim_constant(), + N_READS>; + if (reduce_type_ == ArgReduce::ArgMin) { + kernel = cu::arg_reduce_general< + T, + cu::ArgMin, + block_dim_constant(), + N_READS>; + } + kernel<<>>( + in.data(), + out.data(), + out.size(), + const_param(shape), + const_param(in_strides), + const_param(out_strides), + ndim, + axis_stride, + axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index ecc1f69e0..8e476d30f 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -149,47 +149,55 @@ void binary_op_gpu_inplace( using OutType = cuda_type_t; auto bopt = get_binary_op_type(a, b); if (bopt == BinaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - bool large = a.data_size() > INT32_MAX || - b.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = - &cu::binary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } }); - } else { - auto kernel = cu::binary_g; - auto [num_blocks, 
block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::binary_ss; if (bopt == BinaryOpType::ScalarVector) { kernel = cu::binary_sv; @@ -199,7 +207,7 @@ void binary_op_gpu_inplace( kernel = cu::binary_vv; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel, out.data_size(), out.shape(), out.strides(), large()); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 7cc428281..3cdf2a89f 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -148,49 +148,54 @@ void binary_op_gpu_inplace( auto bopt = get_binary_op_type(a, b); if (bopt == BinaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, out_a); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - bool large = a.data_size() > INT32_MAX || - b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = - cu::binary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out_a.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + auto [shape, strides] = collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } }); - } else { - auto kernel = cu::binary_g; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::binary_ss; if (bopt == BinaryOpType::ScalarVector) { kernel = cu::binary_sv; @@ -204,7 +209,7 @@ void binary_op_gpu_inplace( out_a.data_size(), out_a.shape(), out_a.strides(), - LARGE); + large()); kernel<<>>( a.data(), b.data(), diff --git 
a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 84b8f8aa8..15858ded0 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -38,16 +38,16 @@ void copy_contiguous( encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; auto kernel = cu::copy_s; if (ctype == CopyType::Vector) { kernel = cu::copy_v; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel, out.data_size(), out.shape(), out.strides(), large()); kernel<<>>( in.data() + in_offset, out.data() + out_offset, diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index a6ef33896..b2703e4bf 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -58,44 +58,46 @@ void copy_general( encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - size_t data_size = 1; - for (auto& s : shape) - data_size *= s; - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_nd; - auto [num_blocks, block_dims] = get_launch_args( - kernel, data_size, shape, out.strides(), large); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out)); + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + auto kernel = + cu::copy_gg_nd; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + kernel<<>>( + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + kernel<<>>( + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg; - auto [num_blocks, block_dims] = - get_launch_args(kernel, data_size, shape, out.strides(), large); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim); - } - }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu 
b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 39e50f5eb..68ad005d2 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -64,44 +64,50 @@ void copy_general_dynamic( encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_dynamic_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::copy_gg_dynamic_nd< + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + dynamic_offset_in.data(), + dynamic_offset_out.data()); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg_dynamic; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim, - dynamic_offset_in.data(), - dynamic_offset_out.data()); - } - }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index d025a5a67..d83ba0854 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -53,38 +53,41 @@ void copy_general_input( encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) { - using InType = cuda_type_t; - using OutType = cuda_type_t; - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in)); + dispatch_bool( + in.data_size() > 
INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::copy_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + ndim); + } }); - } else { // ndim >= 4 - auto kernel = cu::copy_g; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - ndim); - } - }); }); }); }); diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index b1fe875bd..b0058b618 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -6,6 +6,8 @@ #pragma once +#include + #include "mlx/array.h" #include "mlx/backend/cuda/device/utils.cuh" @@ -17,60 +19,46 @@ namespace mlx::core { -// Convert a number between 1~3 to constexpr. -#define MLX_SWITCH_1_2_3(N, NDIM, ...) \ - switch (N) { \ - case 1: { \ - constexpr int NDIM = 1; \ - __VA_ARGS__; \ - break; \ - } \ - case 2: { \ - constexpr int NDIM = 2; \ - __VA_ARGS__; \ - break; \ - } \ - case 3: { \ - constexpr int NDIM = 3; \ - __VA_ARGS__; \ - break; \ - } \ +template +void dispatch_1_2_3(int n, F&& f) { + switch (n) { + case 1: + f(std::integral_constant{}); + break; + case 2: + f(std::integral_constant{}); + break; + case 3: + f(std::integral_constant{}); + break; } +} -// Like MLX_SWITCH_ALL_TYPES but for booleans. -#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \ - if (BOOL) { \ - constexpr bool BOOL_ALIAS = true; \ - __VA_ARGS__; \ - } else { \ - constexpr bool BOOL_ALIAS = false; \ - __VA_ARGS__; \ +template +void dispatch_bool(bool v, F&& f) { + if (v) { + f(std::true_type{}); + } else { + f(std::false_type{}); } +} -// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2. -#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) 
\ - { \ - uint32_t _num_threads = NUM_THREADS; \ - if (_num_threads <= WARP_SIZE) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 2) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 4) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 8) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 16) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \ - __VA_ARGS__; \ - } else { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \ - __VA_ARGS__; \ - } \ +template +void dispatch_block_dim(int threads, F&& f) { + if (threads <= WARP_SIZE) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 2) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 4) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 8) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 16) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); } +} // Maps CPU types to CUDA types. template diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index 27e4ddd9d..23f0b168f 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -260,20 +260,21 @@ void LayerNorm::eval_gpu( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { - using DataType = cuda_type_t; constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::layer_norm; - kernel<<>>( - x.data(), - w.data(), - b.data(), - out.data(), - eps_, - axis_size, - w_stride, - b_stride); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm; + kernel<<>>( + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); + }); }); }); } @@ -358,21 +359,26 @@ void LayerNormVJP::eval_gpu( encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - MLX_SWITCH_BOOL(has_w, HAS_W, { - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::layer_norm_vjp; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm_vjp< + DataType, + has_w_constant(), + block_dim(), + N_READS>; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); }); diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index 3615c8291..5d6bf437d 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -145,13 +145,14 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { - using DataType = cuda_type_t; 
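// A minimal, host-only sketch of the tag-dispatch pattern these helpers adopt
// (illustrative names; not part of this diff): the runtime value is wrapped in
// std::integral_constant, so the generic lambda receives it as a compile-time
// constant and can use it as a template argument, playing the role of the
// NDIM/BOOL_ALIAS constants that the removed macros used to introduce.
#include <cstdio>
#include <type_traits>

template <typename F>
void dispatch_1_2_3_sketch(int n, F&& f) {
  switch (n) {
    case 1: f(std::integral_constant<int, 1>{}); break;
    case 2: f(std::integral_constant<int, 2>{}); break;
    case 3: f(std::integral_constant<int, 3>{}); break;
  }
}

template <typename F>
void dispatch_bool_sketch(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

template <bool Large, int NDim>
void fake_kernel() {
  std::printf("Large=%d NDim=%d\n", int(Large), NDim);
}

int main() {
  bool large = true;
  int ndim = 2;
  // Nested dispatch, mirroring the binary/copy/ternary call sites in this diff.
  dispatch_bool_sketch(large, [&](auto large_tag) {
    dispatch_1_2_3_sketch(ndim, [&](auto ndim_tag) {
      // large_tag() and ndim_tag() are constant expressions, so they can
      // parameterize a template even though the tags are function parameters.
      fake_kernel<large_tag(), ndim_tag()>();
    });
  });
}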
constexpr int N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::logsumexp; - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::logsumexp; + kernel<<>>( + in.data(), out.data(), axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 51f644622..a6ccd5ae9 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -112,7 +112,8 @@ void all_reduce( encoder.set_output_array(intermediate); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(dt, [&](auto type_tag) { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t; using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; @@ -136,7 +137,8 @@ void all_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(dt, [&](auto type_tag) { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t; using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index e88b09f19..78f6b93bc 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -216,10 +216,10 @@ void col_reduce_looped( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - using T = cuda_type_t; + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. (sigh) @@ -230,7 +230,8 @@ void col_reduce_looped( constexpr int BN = 32; dim3 grid = output_grid_for_col_reduce(out, args, BN); int blocks = BM * BN / N_READS; - auto kernel = cu::col_reduce_looped; + auto kernel = + cu::col_reduce_looped; kernel<<>>(indata, out.data(), args); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 0f9e7a993..296a4e611 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -34,7 +34,8 @@ void init_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); using T = cuda_type_t; using U = typename cu::ReduceResult::type; auto kernel = cu::init_reduce; diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index a7262bcc2..d0eb3f5c5 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. 
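// A host-only sketch of the block-size selection that dispatch_block_dim
// performs for the logsumexp/layer_norm kernels above (and softmax/row_reduce
// later in this diff): the thread count is rounded up to one of a fixed set of
// warp-size multiples, as in the removed MLX_SWITCH_BLOCK_DIM macro, and the
// chosen size is handed to the callback as a compile-time constant. A warp
// size of 32 is assumed here for illustration, and the 8x/16x steps of the
// real helper are omitted for brevity.
#include <cstdio>
#include <type_traits>

constexpr int kWarp = 32; // illustrative stand-in for WARP_SIZE

template <typename F>
void dispatch_block_dim_sketch(int threads, F&& f) {
  if (threads <= kWarp) {
    f(std::integral_constant<int, kWarp>{});
  } else if (threads <= 2 * kWarp) {
    f(std::integral_constant<int, 2 * kWarp>{});
  } else if (threads <= 4 * kWarp) {
    f(std::integral_constant<int, 4 * kWarp>{});
  } else {
    f(std::integral_constant<int, kWarp * kWarp>{}); // cap at 1024 threads
  }
}

template <int BlockDim>
void fake_launch(int axis_size) {
  std::printf("block_dim=%d for axis_size=%d\n", BlockDim, axis_size);
}

int main() {
  constexpr int n_reads = 4;                          // mirrors N_READS above
  int axis_size = 1000;
  int threads = (axis_size + n_reads - 1) / n_reads;  // ceil_div(axis_size, N_READS)
  dispatch_block_dim_sketch(threads, [&](auto block_dim) {
    fake_launch<block_dim()>(axis_size); // block_dim() is a constant expression
  });
}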
+#include + #include "mlx/backend/common/reduce.h" #include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" @@ -9,43 +11,35 @@ namespace mlx::core { -// Dispatch dynamic ndim to constexpr. -// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file. -#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \ - if (ndim == 1) { \ - constexpr uint32_t NDIM = 1; \ - __VA_ARGS__; \ - } else if (ndim == 2) { \ - constexpr uint32_t NDIM = 2; \ - __VA_ARGS__; \ - } else { \ - constexpr uint32_t NDIM = 5; \ - __VA_ARGS__; \ +template +void dispatch_reduce_ndim(int ndim, F&& f) { + if (ndim == 1) { + f(std::integral_constant{}); + } else if (ndim == 2) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); } +} -// Dispatch reduce ops to constexpr. -#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \ - if (REDUCE == Reduce::ReduceType::And) { \ - using OP = cu::And; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Or) { \ - using OP = cu::Or; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Sum) { \ - using OP = cu::Sum; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Prod) { \ - using OP = cu::Prod; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Max) { \ - using OP = cu::Max; \ - __VA_ARGS__; \ - } else if (REDUCE == Reduce::ReduceType::Min) { \ - using OP = cu::Min; \ - __VA_ARGS__; \ - } else { \ - throw std::invalid_argument("Unknown reduce type."); \ +template +void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) { + if (reduce_type == Reduce::ReduceType::And) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Or) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Sum) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Prod) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Max) { + f(type_identity{}); + } else if (reduce_type == Reduce::ReduceType::Min) { + f(type_identity{}); + } else { + throw std::invalid_argument("Unknown reduce type."); } +} void all_reduce( cu::CommandEncoder& encoder, diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 88cc2a1fc..4578dbad0 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -247,9 +247,9 @@ void row_reduce_simple( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. (sigh) @@ -295,9 +295,9 @@ void row_reduce_looped( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_all_types(in.dtype(), [&](auto type_tag) { - using CTYPE = MLX_GET_TYPE(type_tag); - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. 
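// A sketch of the type-tag idiom used by dispatch_reduce_ops above: the reduce
// op is a *type*, so it is passed through an empty identity tag and recovered
// in the callback with decltype, which is presumably what MLX_GET_TYPE does at
// the call sites. The identity_tag, SumOp/MaxOp, and ReduceKind names below
// are illustrative stand-ins, not the backend's cu::Sum/cu::Max or its
// type_identity helper.
#include <cstdio>

template <typename T>
struct identity_tag {
  using type = T;
};

struct SumOp {
  static int apply(int a, int b) { return a + b; }
};
struct MaxOp {
  static int apply(int a, int b) { return a > b ? a : b; }
};

enum class ReduceKind { Sum, Max };

template <typename F>
void dispatch_reduce_ops_sketch(ReduceKind kind, F&& f) {
  if (kind == ReduceKind::Sum) {
    f(identity_tag<SumOp>{});
  } else {
    f(identity_tag<MaxOp>{});
  }
}

int main() {
  dispatch_reduce_ops_sketch(ReduceKind::Max, [](auto tag) {
    using OP = typename decltype(tag)::type; // recover the op type from the tag
    std::printf("%d\n", OP::apply(3, 7));    // prints 7
  });
}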
(sigh) @@ -313,10 +313,16 @@ void row_reduce_looped( // Pick the kernel auto kernel = cu::row_reduce_looped; - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - MLX_SWITCH_BLOCK_DIM(threads, THREADS, { - kernel = cu::row_reduce_looped; - block.x = THREADS; + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + dispatch_block_dim(threads, [&](auto threads_constant) { + kernel = cu::row_reduce_looped< + T, + U, + OP, + reduce_ndim(), + threads_constant(), + N_READS>; + block.x = threads_constant(); }); }); diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index bc4c86666..7b87f2947 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -226,18 +226,19 @@ void RMSNorm::eval_gpu( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { - using DataType = cuda_type_t; constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::rms_norm; - kernel<<>>( - x.data(), - w.data(), - out.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::rms_norm; + kernel<<>>( + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); + }); }); }); } @@ -312,21 +313,27 @@ void RMSNormVJP::eval_gpu( encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - MLX_SWITCH_BOOL(has_w, HAS_W, { - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::rms_norm_vjp; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + auto kernel = cu::rms_norm_vjp< + DataType, + has_w_constant(), + block_dim(), + N_READS>; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); }); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index a1081622b..a7d7b27ce 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -311,11 +311,11 @@ void RoPE::eval_gpu( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { - using DataType = cuda_type_t; - MLX_SWITCH_BOOL(traditional_, TRADITIONAL, { - MLX_SWITCH_BOOL(forward_, FORWARD, { + dispatch_bool(traditional_, [&](auto traditional) { + dispatch_bool(forward_, [&](auto forward) { + using DataType = cuda_type_t; if (single && !with_freqs) { - auto kernel = cu::rope_single; + auto kernel = cu::rope_single; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -327,7 +327,8 @@ void RoPE::eval_gpu( mat_size, dims); } else if (single) { - auto kernel = cu::rope_single_freqs; + auto kernel = + cu::rope_single_freqs; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -340,7 +341,7 @@ void RoPE::eval_gpu( dims, 
inputs[2].strides(0)); } else if (with_freqs) { - auto kernel = cu::rope_freqs; + auto kernel = cu::rope_freqs; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; @@ -358,7 +359,7 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else { - auto kernel = cu::rope; + auto kernel = cu::rope; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index e7fb14b11..af9ddf214 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -143,16 +143,17 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { - using DataType = cuda_type_t; constexpr int N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::softmax; - if (precise) { - kernel = cu::softmax; - } - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + kernel<<>>( + in.data(), out.data(), axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index 080e789f6..06c063192 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -97,53 +97,56 @@ void ternary_op_gpu_inplace( auto topt = get_ternary_op_type(a, b, c); if (topt == TernaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - auto& c_strides = strides[2]; - bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - c.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::ternary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::ternary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } }); - } else { - auto kernel = cu::ternary_g; - auto [num_blocks, block_dims] = 
get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.data_size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::ternary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel, out.data_size(), out.shape(), out.strides(), large()); kernel<<>>( a.data(), b.data(),
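// Why calling the tag works in type positions: std::true_type/std::false_type
// (what dispatch_bool passes) and std::integral_constant (what dispatch_1_2_3
// and dispatch_block_dim pass) have a constexpr operator(), so an expression
// like large() is a constant expression even though large is an ordinary
// function parameter. That is presumably how the std::conditional_t lines in
// the contiguous branches above pick a 32- vs 64-bit index type once the
// data_size > INT32_MAX check has been dispatched. A compile-time check of
// that property (illustrative, host-only):
#include <cstdint>
#include <type_traits>

int main() {
  auto check = [](auto large) {
    using IdxT = std::conditional_t<large(), std::int64_t, std::int32_t>;
    static_assert(sizeof(IdxT) == (large() ? 8 : 4),
                  "index width is fixed by the compile-time tag");
    return sizeof(IdxT);
  };
  // Two instantiations of the generic lambda, one per tag type.
  return (check(std::true_type{}) + check(std::false_type{}) == 12) ? 0 : 1;
}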