diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu
index c8a5a962a..90f8561c1 100644
--- a/mlx/backend/cuda/arg_reduce.cu
+++ b/mlx/backend/cuda/arg_reduce.cu
@@ -152,35 +152,29 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
+    dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
+      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       constexpr uint32_t N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-        dim3 block_dims{BLOCK_DIM, 1, 1};
-        auto kernel = &cu::arg_reduce_general<
-            InType,
-            cu::ArgMax<InType>,
-            BLOCK_DIM,
-            N_READS>;
-        if (reduce_type_ == ArgReduce::ArgMin) {
-          kernel = &cu::arg_reduce_general<
-              InType,
-              cu::ArgMin<InType>,
-              BLOCK_DIM,
-              N_READS>;
-        }
-        kernel<<<num_blocks, block_dims, 0, stream>>>(
-            in.data<InType>(),
-            out.data<uint32_t>(),
-            out.size(),
-            const_param(shape),
-            const_param(in_strides),
-            const_param(out_strides),
-            ndim,
-            axis_stride,
-            axis_size);
-      });
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
+            auto kernel =
+                cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
+            if (reduce_type_ == ArgReduce::ArgMin) {
+              kernel = cu::
+                  arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
+            }
+            kernel<<<num_blocks, block_dim(), 0, stream>>>(
+                in.data<T>(),
+                out.data<uint32_t>(),
+                out.size(),
+                const_param(shape),
+                const_param(in_strides),
+                const_param(out_strides),
+                ndim,
+                axis_stride,
+                axis_size);
+          });
     });
   });
 }
diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu
index 9c437cde9..8e476d30f 100644
--- a/mlx/backend/cuda/binary.cu
+++ b/mlx/backend/cuda/binary.cu
@@ -140,54 +140,64 @@ void binary_op_gpu_inplace(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
-      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+    dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
         if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
           using InType = cuda_type_t<CTYPE_IN>;
           using OutType = cuda_type_t<CTYPE_OUT>;
           auto bopt = get_binary_op_type(a, b);
           if (bopt == BinaryOpType::General) {
-            auto [shape, strides] = collapse_contiguous_dims(a, b, out);
-            auto& a_strides = strides[0];
-            auto& b_strides = strides[1];
-            bool large = a.data_size() > INT32_MAX ||
-                b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
-            MLX_SWITCH_BOOL(large, LARGE, {
-              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
-              int ndim = shape.size();
-              if (ndim <= 3) {
-                MLX_SWITCH_1_2_3(ndim, NDIM, {
-                  auto kernel =
-                      &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
-                  auto [num_blocks, block_dims] =
-                      get_launch_args(kernel, out, large);
-                  kernel<<<num_blocks, block_dims, 0, stream>>>(
-                      a.data<InType>(),
-                      b.data<InType>(),
-                      out.data<OutType>(),
-                      out.size(),
-                      const_param(shape),
-                      const_param(a_strides),
-                      const_param(b_strides));
+            dispatch_bool(
+                a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                    out.data_size() > INT32_MAX,
+                [&](auto large) {
+                  using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+                  Shape shape;
+                  std::vector<Strides> strides;
+                  std::tie(shape, strides) =
+                      collapse_contiguous_dims(a, b, out);
+                  auto& a_strides = strides[0];
+                  auto& b_strides = strides[1];
+                  int ndim = shape.size();
+                  if (ndim <= 3) {
+                    dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                      auto
kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } }); - } else { - auto kernel = cu::binary_g; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::binary_ss; if (bopt == BinaryOpType::ScalarVector) { kernel = cu::binary_sv; @@ -197,7 +207,7 @@ void binary_op_gpu_inplace( kernel = cu::binary_vv; } auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel, out.data_size(), out.shape(), out.strides(), large()); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 074c947da..0a68e5f1d 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -138,57 +138,67 @@ void binary_op_gpu_inplace( encoder.set_output_array(out_a); encoder.set_output_array(out_b); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { - MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, { + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_binary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; auto bopt = get_binary_op_type(a, b); if (bopt == BinaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, out_a); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - bool large = a.data_size() > INT32_MAX || - b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = - cu::binary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + out_a.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = + collapse_contiguous_dims(a, b, out_a); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::binary_g_nd< + Op, + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + kernel<<>>( + a.data(), + b.data(), + 
out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides)); + }); + } else { + auto kernel = cu::binary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out_a, large()); + kernel<<>>( + a.data(), + b.data(), + out_a.data(), + out_b.data(), + out_a.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + ndim); + } }); - } else { - auto kernel = cu::binary_g; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out_a, large); - kernel<<>>( - a.data(), - b.data(), - out_a.data(), - out_b.data(), - out_a.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::binary_ss; if (bopt == BinaryOpType::ScalarVector) { kernel = cu::binary_sv; @@ -202,7 +212,7 @@ void binary_op_gpu_inplace( out_a.data_size(), out_a.shape(), out_a.strides(), - LARGE); + large()); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index 789826507..e80fdec8c 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -10,15 +10,6 @@ namespace mlx::core { -#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \ - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ - using InType = cuda_type_t; \ - using OutType = cuda_type_t; \ - __VA_ARGS__; \ - }); \ - }) - void copy_contiguous( cu::CommandEncoder& encoder, CopyType ctype, diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 5f4c9ca8f..15858ded0 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -36,19 +36,23 @@ void copy_contiguous( int64_t in_offset, int64_t out_offset) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; - auto kernel = cu::copy_s; - if (ctype == CopyType::Vector) { - kernel = cu::copy_v; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); - kernel<<>>( - in.data() + in_offset, - out.data() + out_offset, - out.data_size()); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), large()); + kernel<<>>( + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 2dc08c60a..b2703e4bf 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -56,42 +56,48 @@ void copy_general( const Strides& strides_in, const Strides& strides_out) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + 
offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - size_t data_size = 1; - for (auto& s : shape) - data_size *= s; - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, data_size, shape, out.strides(), large); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg; - auto [num_blocks, block_dims] = - get_launch_args(kernel, data_size, shape, out.strides(), large); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim); - } + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + size_t data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto ndim_constant) { + auto kernel = + cu::copy_gg_nd; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + kernel<<>>( + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large()); + kernel<<>>( + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 2e1cf4fba..68ad005d2 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -62,41 +62,52 @@ void copy_general_dynamic( const array& dynamic_offset_in, const array& dynamic_offset_out) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_dynamic_nd; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - dynamic_offset_in.data(), - dynamic_offset_out.data()); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg_dynamic; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim, - dynamic_offset_in.data(), - dynamic_offset_out.data()); - } + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto 
out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = cu::copy_gg_dynamic_nd< + InType, + OutType, + IdxT, + dims_constant()>; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + dynamic_offset_in.data(), + dynamic_offset_out.data()); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, + dynamic_offset_in.data(), + dynamic_offset_out.data()); + } + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index a3bb37e53..d83ba0854 100644 --- a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -51,35 +51,43 @@ void copy_general_input( const Shape& shape, const Strides& strides_in) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_g_nd; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_g; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - ndim); - } + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + dispatch_bool( + in.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + using IdxT = std::conditional_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::copy_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + ndim); + } + }); }); }); }); diff --git a/mlx/backend/cuda/kernel_utils.cuh b/mlx/backend/cuda/kernel_utils.cuh index b1fe875bd..b0058b618 100644 --- a/mlx/backend/cuda/kernel_utils.cuh +++ b/mlx/backend/cuda/kernel_utils.cuh @@ -6,6 +6,8 @@ #pragma once +#include + #include "mlx/array.h" #include 
"mlx/backend/cuda/device/utils.cuh" @@ -17,60 +19,46 @@ namespace mlx::core { -// Convert a number between 1~3 to constexpr. -#define MLX_SWITCH_1_2_3(N, NDIM, ...) \ - switch (N) { \ - case 1: { \ - constexpr int NDIM = 1; \ - __VA_ARGS__; \ - break; \ - } \ - case 2: { \ - constexpr int NDIM = 2; \ - __VA_ARGS__; \ - break; \ - } \ - case 3: { \ - constexpr int NDIM = 3; \ - __VA_ARGS__; \ - break; \ - } \ +template +void dispatch_1_2_3(int n, F&& f) { + switch (n) { + case 1: + f(std::integral_constant{}); + break; + case 2: + f(std::integral_constant{}); + break; + case 3: + f(std::integral_constant{}); + break; } +} -// Like MLX_SWITCH_ALL_TYPES but for booleans. -#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \ - if (BOOL) { \ - constexpr bool BOOL_ALIAS = true; \ - __VA_ARGS__; \ - } else { \ - constexpr bool BOOL_ALIAS = false; \ - __VA_ARGS__; \ +template +void dispatch_bool(bool v, F&& f) { + if (v) { + f(std::true_type{}); + } else { + f(std::false_type{}); } +} -// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2. -#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \ - { \ - uint32_t _num_threads = NUM_THREADS; \ - if (_num_threads <= WARP_SIZE) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 2) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 4) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 8) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \ - __VA_ARGS__; \ - } else if (_num_threads <= WARP_SIZE * 16) { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \ - __VA_ARGS__; \ - } else { \ - constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \ - __VA_ARGS__; \ - } \ +template +void dispatch_block_dim(int threads, F&& f) { + if (threads <= WARP_SIZE) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 2) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 4) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 8) { + f(std::integral_constant{}); + } else if (threads <= WARP_SIZE * 16) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); } +} // Maps CPU types to CUDA types. 
template diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index c71795fad..23f0b168f 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -259,21 +259,22 @@ void LayerNorm::eval_gpu( encoder.set_input_array(b); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::layer_norm; - kernel<<>>( - x.data(), - w.data(), - b.data(), - out.data(), - eps_, - axis_size, - w_stride, - b_stride); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm; + kernel<<>>( + x.data(), + w.data(), + b.data(), + out.data(), + eps_, + axis_size, + w_stride, + b_stride); + }); }); }); } @@ -357,22 +358,27 @@ void LayerNormVJP::eval_gpu( encoder.set_output_array(gx); encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - MLX_SWITCH_BOOL(has_w, HAS_W, { - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::layer_norm_vjp; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::layer_norm_vjp< + DataType, + has_w_constant(), + block_dim(), + N_READS>; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); }); diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index f57f82ea8..5d6bf437d 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -144,14 +144,15 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { constexpr int N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::logsumexp; - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::logsumexp; + kernel<<>>( + in.data(), out.data(), axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index e32befc9c..715e5a232 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -28,7 +28,8 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_output_array(out); encoder.launch_kernel([&, this](cudaStream_t stream) { - MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { + 
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); using OutType = cuda_type_t; CTYPE step = static_cast(start_ + step_) - static_cast(start_); diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 5a7c28041..a6ccd5ae9 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -111,10 +111,11 @@ void all_reduce( encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(dt, CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; kernel<<>>( static_cast(indata), @@ -135,10 +136,11 @@ void all_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(dt, CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(dt, [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; kernel<<>>( static_cast(indata), out.data(), block_step, insize); diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 192a9b3e8..78f6b93bc 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -215,11 +215,12 @@ void col_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. 
(sigh) T* indata = const_cast(in.data()); @@ -229,7 +230,8 @@ void col_reduce_looped( constexpr int BN = 32; dim3 grid = output_grid_for_col_reduce(out, args, BN); int blocks = BM * BN / N_READS; - auto kernel = cu::col_reduce_looped; + auto kernel = + cu::col_reduce_looped; kernel<<>>(indata, out.data(), args); }); }); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 50fe109c4..296a4e611 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -33,10 +33,11 @@ void init_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { - MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + dispatch_all_types(in.dtype(), [&](auto type_tag) { + dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) { + using OP = MLX_GET_TYPE(reduce_type_tag); + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::init_reduce; dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); diff --git a/mlx/backend/cuda/reduce/reduce.cuh b/mlx/backend/cuda/reduce/reduce.cuh index a7262bcc2..d0eb3f5c5 100644 --- a/mlx/backend/cuda/reduce/reduce.cuh +++ b/mlx/backend/cuda/reduce/reduce.cuh @@ -1,5 +1,7 @@ // Copyright © 2025 Apple Inc. +#include + #include "mlx/backend/common/reduce.h" #include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" @@ -9,43 +11,35 @@ namespace mlx::core { -// Dispatch dynamic ndim to constexpr. -// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file. -#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \ - if (ndim == 1) { \ - constexpr uint32_t NDIM = 1; \ - __VA_ARGS__; \ - } else if (ndim == 2) { \ - constexpr uint32_t NDIM = 2; \ - __VA_ARGS__; \ - } else { \ - constexpr uint32_t NDIM = 5; \ - __VA_ARGS__; \ +template +void dispatch_reduce_ndim(int ndim, F&& f) { + if (ndim == 1) { + f(std::integral_constant{}); + } else if (ndim == 2) { + f(std::integral_constant{}); + } else { + f(std::integral_constant{}); } +} -// Dispatch reduce ops to constexpr. -#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) 
\
-  if (REDUCE == Reduce::ReduceType::And) { \
-    using OP = cu::And; \
-    __VA_ARGS__; \
-  } else if (REDUCE == Reduce::ReduceType::Or) { \
-    using OP = cu::Or; \
-    __VA_ARGS__; \
-  } else if (REDUCE == Reduce::ReduceType::Sum) { \
-    using OP = cu::Sum; \
-    __VA_ARGS__; \
-  } else if (REDUCE == Reduce::ReduceType::Prod) { \
-    using OP = cu::Prod; \
-    __VA_ARGS__; \
-  } else if (REDUCE == Reduce::ReduceType::Max) { \
-    using OP = cu::Max; \
-    __VA_ARGS__; \
-  } else if (REDUCE == Reduce::ReduceType::Min) { \
-    using OP = cu::Min; \
-    __VA_ARGS__; \
-  } else { \
-    throw std::invalid_argument("Unknown reduce type."); \
+template <typename F>
+void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {
+  if (reduce_type == Reduce::ReduceType::And) {
+    f(type_identity<cu::And>{});
+  } else if (reduce_type == Reduce::ReduceType::Or) {
+    f(type_identity<cu::Or>{});
+  } else if (reduce_type == Reduce::ReduceType::Sum) {
+    f(type_identity<cu::Sum>{});
+  } else if (reduce_type == Reduce::ReduceType::Prod) {
+    f(type_identity<cu::Prod>{});
+  } else if (reduce_type == Reduce::ReduceType::Max) {
+    f(type_identity<cu::Max>{});
+  } else if (reduce_type == Reduce::ReduceType::Min) {
+    f(type_identity<cu::Min>{});
+  } else {
+    throw std::invalid_argument("Unknown reduce type.");
   }
+}
 
 void all_reduce(
     cu::CommandEncoder& encoder,
diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu
index 6a8a35311..4578dbad0 100644
--- a/mlx/backend/cuda/reduce/row_reduce.cu
+++ b/mlx/backend/cuda/reduce/row_reduce.cu
@@ -246,10 +246,11 @@ void row_reduce_simple(
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
 
         // Cub doesn't like const pointers for vectorized loads. (sigh)
         T* indata = const_cast<T*>(in.data<T>());
@@ -293,10 +294,11 @@ void row_reduce_looped(
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
 
         // Cub doesn't like const pointers for vectorized loads.
(sigh) T* indata = const_cast(in.data()); @@ -311,10 +313,16 @@ void row_reduce_looped( // Pick the kernel auto kernel = cu::row_reduce_looped; - MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { - MLX_SWITCH_BLOCK_DIM(threads, THREADS, { - kernel = cu::row_reduce_looped; - block.x = THREADS; + dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) { + dispatch_block_dim(threads, [&](auto threads_constant) { + kernel = cu::row_reduce_looped< + T, + U, + OP, + reduce_ndim(), + threads_constant(), + N_READS>; + block.x = threads_constant(); }); }); diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 3c521b90d..7b87f2947 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -225,19 +225,20 @@ void RMSNorm::eval_gpu( encoder.set_input_array(w); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { constexpr uint32_t N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::rms_norm; - kernel<<>>( - x.data(), - w.data(), - out.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::rms_norm; + kernel<<>>( + x.data(), + w.data(), + out.data(), + eps_, + axis_size, + w_stride); + }); }); }); } @@ -311,22 +312,28 @@ void RMSNormVJP::eval_gpu( encoder.set_output_array(gx); encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, { - using DataType = cuda_type_t; - constexpr int N_READS = 4; - MLX_SWITCH_BOOL(has_w, HAS_W, { - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::rms_norm_vjp; - kernel<<>>( - x.data(), - w.data(), - g.data(), - gx.data(), - gw_temp.data(), - eps_, - axis_size, - w_stride); - }); + dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { + dispatch_bool(has_w, [&](auto has_w_constant) { + constexpr int N_READS = 4; + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + constexpr int N_READS = 4; + auto kernel = cu::rms_norm_vjp< + DataType, + has_w_constant(), + block_dim(), + N_READS>; + kernel<<>>( + x.data(), + w.data(), + g.data(), + gx.data(), + gw_temp.data(), + eps_, + axis_size, + w_stride); + }); }); }); }); diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 1d8307811..a7d7b27ce 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -310,12 +310,12 @@ void RoPE::eval_gpu( encoder.set_input_array(offset); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, { - using DataType = cuda_type_t; - MLX_SWITCH_BOOL(traditional_, TRADITIONAL, { - MLX_SWITCH_BOOL(forward_, FORWARD, { + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + dispatch_bool(traditional_, [&](auto traditional) { + dispatch_bool(forward_, [&](auto forward) { + using DataType = cuda_type_t; if (single && !with_freqs) { - auto kernel = cu::rope_single; + auto kernel = cu::rope_single; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -327,7 +327,8 
@@ void RoPE::eval_gpu( mat_size, dims); } else if (single) { - auto kernel = cu::rope_single_freqs; + auto kernel = + cu::rope_single_freqs; uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); kernel<<>>( @@ -340,7 +341,7 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else if (with_freqs) { - auto kernel = cu::rope_freqs; + auto kernel = cu::rope_freqs; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; @@ -358,7 +359,7 @@ void RoPE::eval_gpu( dims, inputs[2].strides(0)); } else { - auto kernel = cu::rope; + auto kernel = cu::rope; uint3 dims = make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); dims.z = (dims.z + 3) / 4; diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 652e6da19..af9ddf214 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -142,17 +142,18 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { constexpr int N_READS = 4; - MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { - auto kernel = cu::softmax; - if (precise) { - kernel = cu::softmax; - } - kernel<<>>( - in.data(), out.data(), axis_size); - }); + dispatch_block_dim( + cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) { + using DataType = cuda_type_t; + auto kernel = cu::softmax; + if (precise) { + kernel = cu::softmax; + } + kernel<<>>( + in.data(), out.data(), axis_size); + }); }); }); } diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 5cbffc0f4..2c5599bed 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -76,6 +76,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) { temp.data(), size, args...)); } +struct OffsetTransform { + int nsort; + + int __device__ operator()(int i) { + return i * nsort; + } +}; + void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { array out = out_; auto& encoder = cu::get_command_encoder(s); @@ -106,12 +114,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); if constexpr (!std::is_same_v) { using Type = cuda_type_t; auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - [nsort] __device__(int i) { return i * nsort; }); + thrust::make_counting_iterator(0), OffsetTransform{nsort}); if (argsort) { // Indices in the sorted dimension. 
array indices( diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index e33af3c80..1d6535100 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -92,58 +92,63 @@ void ternary_op_gpu_inplace( encoder.set_input_array(c); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, { - using DType = cuda_type_t; + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using DType = cuda_type_t; auto topt = get_ternary_op_type(a, b, c); if (topt == TernaryOpType::General) { - auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); - auto& a_strides = strides[0]; - auto& b_strides = strides[1]; - auto& c_strides = strides[2]; - bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || - c.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::ternary_g_nd; - auto [num_blocks, block_dims] = - get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides)); + dispatch_bool( + a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || + c.data_size() > INT32_MAX || out.data_size() > INT32_MAX, + [&](auto large) { + using IdxT = std::conditional_t; + Shape shape; + std::vector strides; + std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out); + auto& a_strides = strides[0]; + auto& b_strides = strides[1]; + auto& c_strides = strides[2]; + int ndim = shape.size(); + if (ndim <= 3) { + dispatch_1_2_3(ndim, [&](auto dims_constant) { + auto kernel = + cu::ternary_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides)); + }); + } else { + auto kernel = cu::ternary_g; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large()); + kernel<<>>( + a.data(), + b.data(), + c.data(), + out.data(), + out.data_size(), + const_param(shape), + const_param(a_strides), + const_param(b_strides), + const_param(c_strides), + ndim); + } }); - } else { - auto kernel = cu::ternary_g; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - a.data(), - b.data(), - c.data(), - out.data(), - out.data_size(), - const_param(shape), - const_param(a_strides), - const_param(b_strides), - const_param(c_strides), - ndim); - } - }); } else { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; + dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; auto kernel = cu::ternary_v; auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel, out.data_size(), out.shape(), out.strides(), large()); kernel<<>>( a.data(), b.data(), diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index e45144eda..4f9bac29f 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -79,8 +79,10 @@ void unary_op_gpu_inplace( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { - 
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_unary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 4a3d8be30..35731f6eb 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -25,22 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) { } const char* dtype_to_cuda_type(const Dtype& dtype) { - if (dtype == float16) { - return "__half"; + switch (dtype) { + case bool_: + return "bool"; + case int8: + return "int8_t"; + case int16: + return "int16_t"; + case int32: + return "int32_t"; + case int64: + return "int64_t"; + case uint8: + return "uint8_t"; + case uint16: + return "uint16_t"; + case uint32: + return "uint32_t"; + case uint64: + return "uint64_t"; + case float16: + return "__half"; + case bfloat16: + return "__nv_bfloat16"; + case float32: + return "float"; + case float64: + return "double"; + case complex64: + return "cuComplex"; + default: + return "unknown"; } - if (dtype == bfloat16) { - return "__nv_bfloat16"; - } - if (dtype == complex64) { - return "cuComplex"; - } -#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ - if (dtype == DTYPE) { \ - return #CPP_TYPE; \ - } - MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) -#undef SPECIALIZE_DtypeToString - return nullptr; } } // namespace mlx::core diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp index a4448536d..9f10e6a9a 100644 --- a/mlx/dtype_utils.cpp +++ b/mlx/dtype_utils.cpp @@ -5,16 +5,38 @@ namespace mlx::core { const char* dtype_to_string(Dtype arg) { - if (arg == bool_) { - return "bool"; + switch (arg) { + case bool_: + return "bool"; + case int8: + return "int8"; + case int16: + return "int16"; + case int32: + return "int32"; + case int64: + return "int64"; + case uint8: + return "uint8"; + case uint16: + return "uint16"; + case uint32: + return "uint32"; + case uint64: + return "uint64"; + case float16: + return "float16"; + case bfloat16: + return "bfloat16"; + case float32: + return "float32"; + case float64: + return "float64"; + case complex64: + return "complex64"; + default: + return "unknown"; } -#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ - if (DTYPE == arg) { \ - return #DTYPE; \ - } - MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) -#undef SPECIALIZE_DtypeToString - return "(unknown)"; } } // namespace mlx::core diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h index 55de890f2..27fe432f6 100644 --- a/mlx/dtype_utils.h +++ b/mlx/dtype_utils.h @@ -1,207 +1,106 @@ // Copyright © 2025 Apple Inc. -// Copyright © Meta Platforms, Inc. and affiliates. -// -// This source code is licensed under the BSD-style license found in -// https://github.com/pytorch/executorch/blob/main/LICENSE -// -// Forked from -// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h #pragma once -#include "mlx/dtype.h" +#include -#include +#include "mlx/dtype.h" +#include "mlx/utils.h" namespace mlx::core { // Return string representation of dtype. const char* dtype_to_string(Dtype arg); -// Macros that iterate across different subsets of Dtypes. 
-//
-// For all of these macros, the final `_` parameter is the name of another macro
-// that takes two parameters: the name of a C type, and the name of the
-// corresponding Dtype enumerator.
-//
-// Note that these macros should use fully-qualified namespaces (starting with
-// `::`) to ensure that they can be called safely in any arbitrary namespace.
-#define MLX_FORALL_INT_TYPES(_) \
-  _(uint8_t, uint8) \
-  _(uint16_t, uint16) \
-  _(uint32_t, uint32) \
-  _(uint64_t, uint64) \
-  _(int8_t, int8) \
-  _(int16_t, int16) \
-  _(int32_t, int32) \
-  _(int64_t, int64)
+#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \
+  case DTYPE:                                       \
+    f(type_identity<TYPE>{});                       \
+    break
 
-#define MLX_FORALL_FLOAT_TYPES(_) \
-  _(float16_t, float16) \
-  _(float, float32) \
-  _(double, float64) \
-  _(bfloat16_t, bfloat16)
+#define MLX_INTERNAL_DTYPE_SWITCH_INTS()            \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t);     \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)
 
-// Calls the provided macro on every Dtype, providing the C type and the
-// Dtype name to each call.
-//
-// @param _ A macro that takes two parameters: the name of a C type, and the
-// name of the corresponding Dtype enumerator.
-#define MLX_FORALL_DTYPES(_) \
-  MLX_FORALL_INT_TYPES(_) \
-  MLX_FORALL_FLOAT_TYPES(_) \
-  _(bool, bool_) \
-  _(complex64_t, complex64)
+#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS()            \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float);     \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double)
 
-// Maps Dtypes to C++ types.
-template <Dtype::Val val>
-struct DtypeToCppType;
-
-#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
-  template <> \
-  struct DtypeToCppType<DTYPE> { \
-    using type = CPP_TYPE; \
-  };
-
-MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
-
-#undef SPECIALIZE_DtypeToCppType
-
-// Maps C++ types to Dtypes.
+// This already exists in C++20 but in C++20 we can also just use templated
+// lambdas which will make this so much nicer.
 template <typename T>
-struct CppTypeToDtype;
+struct type_identity {
+  using type = T;
+};
 
-#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \
-  template <> \
-  struct CppTypeToDtype<CPP_TYPE> \
-      : std::integral_constant {};
+#define MLX_GET_TYPE(x) typename decltype(x)::type
+#define MLX_GET_VALUE(x) decltype(x)::value
 
-MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype)
-
-#undef SPECIALIZE_CppTypeToDtype
-
-// Helper macros for switch case macros (see below)
-//
-// These macros are not meant to be used directly. They provide an easy way to
-// generate a switch statement that can handle subsets of Dtypes supported.
-
-#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
-  case enum_type: { \
-    using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
-    __VA_ARGS__; \
-    break; \
-  }
+template <typename F>
+void dispatch_all_types(Dtype dt, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
+  }
+}
 
-#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \
-  switch (TYPE) { \
-    __VA_ARGS__ \
-    default: \
-      throw std::invalid_argument(fmt::format( \
-          "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \
+template <typename F>
+void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only integer types supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
   }
+}
 
-#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
+template <typename F>
+void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only float types supported but " << dt << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
 
-#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__)
+template <typename F>
+void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only integer and float types supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
 
-#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE( \
-      ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
-
-// Switch case macros
-//
-// These macros provide an easy way to generate switch statements that apply a
-// common lambda function to subsets of Dtypes supported by MLX.
-// The lambda function can type specialize to the ctype associated with the
-// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
-//
-// Arguments:
-//   - ADDITIONAL: Additional Dtype case to add
-//   - TYPE: The Dtype to handle through the switch statement
-//   - NAME: A name for this operation which will be used in error messages
-//   - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
-//   - ...: A statement to be applied to each Dtype case
-//
-// An example usage is:
-//
-//   MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
-//     output.data[0] = input.data[0];
-//   });
-//
-// Note that these can be nested as well:
-//
-//   MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
-//     MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
-//       output.data[0] = input.data[0];
-//     });
-//   });
-//
-// These macros are adapted from Dispatch.h in the ATen library. The primary
-// difference is that the CTYPE_ALIAS argument is exposed to users, which is
-// used to alias the ctype associated with the Dtype that is being handled.
-
-#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
-  switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
-
-#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED( \
-      TYPE, \
-      NAME, \
-      MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
-
-#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED( \
-      TYPE, \
-      NAME, \
-      MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
-
-#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED( \
-      TYPE, \
-      NAME, \
-      MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
-
-#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED( \
-      TYPE, \
-      NAME, \
-      MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
+template <typename F>
+void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only real numbers supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
 
 } // namespace mlx::core
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 61b9da3a2..e53a7a97f 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -253,7 +253,9 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
 
 std::ostream& operator<<(std::ostream& os, array a) {
   a.eval();
-  MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
+  dispatch_all_types(a.dtype(), [&](auto type_tag) {
+    print_array<MLX_GET_TYPE(type_tag)>(os, a);
+  });
   return os;
 }
 
@@ -321,8 +323,9 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
 }
 
 iinfo::iinfo(Dtype dtype) : dtype(dtype) {
-  MLX_SWITCH_INT_TYPES_CHECKED(
-      dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
+  dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) {
+    set_iinfo_limits<MLX_GET_TYPE(type_tag)>(min, max);
+  });
 }
 
 } // namespace mlx::core
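Note on the core trick: the dispatch_* helpers work because std::integral_constant (and std::true_type / std::false_type) expose a constexpr operator() that reads no object state, so a tag received by value in a generic lambda can still appear in template arguments, as block_dim(), large(), and dims_constant() do throughout the patch. A minimal standalone sketch of the pattern (plain C++17; fake_kernel and these simplified helpers are illustrative stand-ins, not the MLX headers):

#include <cstdio>
#include <type_traits>

// Simplified analogues of the helpers added to kernel_utils.cuh.
template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}

template <typename F>
void dispatch_1_2_3(int n, F&& f) {
  switch (n) {
    case 1: f(std::integral_constant<int, 1>{}); break;
    case 2: f(std::integral_constant<int, 2>{}); break;
    case 3: f(std::integral_constant<int, 3>{}); break;
  }
}

// Stand-in for a CUDA kernel that needs compile-time parameters.
template <typename IdxT, int NDim>
void fake_kernel() {
  std::printf("NDim=%d, sizeof(IdxT)=%zu\n", NDim, sizeof(IdxT));
}

int main() {
  bool large = true; // runtime values...
  int ndim = 2;
  dispatch_bool(large, [&](auto large_tag) {
    // large_tag() is a constant expression: integral_constant's
    // constexpr operator() touches no member data.
    using IdxT = std::conditional_t<large_tag(), long long, int>;
    dispatch_1_2_3(ndim, [&](auto ndim_tag) {
      fake_kernel<IdxT, ndim_tag()>(); // ...become template arguments
    });
  });
}

Each runtime branch instantiates the lambda once per constant, so the set of compiled kernel specializations is identical to what the old MLX_SWITCH_* macros produced, without macro-expanded bodies.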
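Note on type_identity / MLX_GET_TYPE: a C++17 lambda cannot declare a type template parameter, so the selected dtype's C++ type rides along as a value of an empty tag whose nested ::type carries it, and the macro unwraps it with decltype. With C++20 one could write []<typename T>(type_identity<T>) and drop the macro, which is what the comment kept in dtype_utils.h alludes to. A standalone sketch (the two-type dispatcher and the GET_TYPE name are hypothetical simplifications of the dispatch_* functions above):

#include <cstdio>

// Same shape as dtype_utils.h's type_identity and MLX_GET_TYPE.
template <typename T>
struct type_identity {
  using type = T;
};

#define GET_TYPE(x) typename decltype(x)::type

// Hypothetical dispatcher mapping a runtime flag to float or double.
template <typename F>
void dispatch_float_or_double(bool use_double, F&& f) {
  if (use_double) {
    f(type_identity<double>{});
  } else {
    f(type_identity<float>{});
  }
}

int main() {
  dispatch_float_or_double(true, [&](auto type_tag) {
    using T = GET_TYPE(type_tag); // T is double on this path
    T third = T(1) / T(3);
    std::printf("%f\n", static_cast<double>(third));
  });
}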
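Note on the sort.cu change: the [nsort] __device__ lambda that OffsetTransform replaces is a CUDA extended lambda, which requires nvcc's --extended-lambda flag and restricts how its closure type may be used, while a plain functor with a __device__ call operator is an ordinary type with no such limits. A sketch of the same idea in isolation (CUDA; the free function and parameter names here are illustrative, only the functor mirrors the patch):

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

// Ordinary functor usable anywhere a type is needed; no special
// compiler flags required, unlike an extended __device__ lambda.
struct OffsetTransform {
  int nsort; // length of one sorted segment (row)

  int __device__ operator()(int i) {
    return i * nsort; // segment i starts at offset i * nsort
  }
};

void make_offsets_example(int nsort) {
  // Lazily computes offsets[i] = i * nsort on the device when read,
  // e.g. as the segment-offset iterator for a segmented sort.
  auto offsets = thrust::make_transform_iterator(
      thrust::make_counting_iterator(0), OffsetTransform{nsort});
  (void)offsets;
}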