From 45c43dd24ab6701917f97b0b475157810fa92d09 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 29 Jun 2025 02:29:23 -0700 Subject: [PATCH] Start changing MLX_SWITCH to templates --- mlx/backend/cuda/arg_reduce.cu | 19 +- mlx/backend/cuda/binary.cu | 6 +- mlx/backend/cuda/binary_two.cu | 6 +- mlx/backend/cuda/copy/copy.cuh | 9 - mlx/backend/cuda/copy/copy_contiguous.cu | 30 +- mlx/backend/cuda/copy/copy_general.cu | 64 +++-- mlx/backend/cuda/copy/copy_general_dynamic.cu | 61 ++-- mlx/backend/cuda/copy/copy_general_input.cu | 53 ++-- mlx/backend/cuda/layer_norm.cu | 8 +- mlx/backend/cuda/logsumexp.cu | 4 +- mlx/backend/cuda/primitives.cu | 3 +- mlx/backend/cuda/reduce/all_reduce.cu | 12 +- mlx/backend/cuda/reduce/col_reduce.cu | 5 +- mlx/backend/cuda/reduce/init_reduce.cu | 6 +- mlx/backend/cuda/reduce/row_reduce.cu | 10 +- mlx/backend/cuda/rms_norm.cu | 8 +- mlx/backend/cuda/rope.cu | 4 +- mlx/backend/cuda/softmax.cu | 4 +- mlx/backend/cuda/sort.cu | 14 +- mlx/backend/cuda/ternary.cu | 4 +- mlx/backend/cuda/unary.cu | 6 +- mlx/backend/cuda/utils.cpp | 8 +- mlx/dtype_utils.cpp | 39 ++- mlx/dtype_utils.h | 263 ++++++------------ mlx/utils.cpp | 9 +- 25 files changed, 296 insertions(+), 359 deletions(-) diff --git a/mlx/backend/cuda/arg_reduce.cu b/mlx/backend/cuda/arg_reduce.cu index c8a5a962a..2e91ae61f 100644 --- a/mlx/backend/cuda/arg_reduce.cu +++ b/mlx/backend/cuda/arg_reduce.cu @@ -152,26 +152,19 @@ void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, { - using InType = cuda_type_t; + dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { + using T = cuda_type_t; constexpr uint32_t N_READS = 4; MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); dim3 block_dims{BLOCK_DIM, 1, 1}; - auto kernel = &cu::arg_reduce_general< - InType, - cu::ArgMax, - BLOCK_DIM, - N_READS>; + auto kernel = + cu::arg_reduce_general, BLOCK_DIM, N_READS>; if (reduce_type_ == ArgReduce::ArgMin) { - kernel = &cu::arg_reduce_general< - InType, - cu::ArgMin, - BLOCK_DIM, - N_READS>; + kernel = cu::arg_reduce_general, BLOCK_DIM, N_READS>; } kernel<<>>( - in.data(), + in.data(), out.data(), out.size(), const_param(shape), diff --git a/mlx/backend/cuda/binary.cu b/mlx/backend/cuda/binary.cu index 9c437cde9..ecc1f69e0 100644 --- a/mlx/backend/cuda/binary.cu +++ b/mlx/backend/cuda/binary.cu @@ -140,8 +140,10 @@ void binary_op_gpu_inplace( encoder.set_input_array(b); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_binary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; diff --git a/mlx/backend/cuda/binary_two.cu b/mlx/backend/cuda/binary_two.cu index 074c947da..7cc428281 100644 --- a/mlx/backend/cuda/binary_two.cu +++ b/mlx/backend/cuda/binary_two.cu @@ -138,8 +138,10 @@ void binary_op_gpu_inplace( encoder.set_output_array(out_a); encoder.set_output_array(out_b); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, { - 
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, { + dispatch_all_types(a.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_binary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; diff --git a/mlx/backend/cuda/copy/copy.cuh b/mlx/backend/cuda/copy/copy.cuh index 789826507..e80fdec8c 100644 --- a/mlx/backend/cuda/copy/copy.cuh +++ b/mlx/backend/cuda/copy/copy.cuh @@ -10,15 +10,6 @@ namespace mlx::core { -#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \ - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \ - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \ - using InType = cuda_type_t; \ - using OutType = cuda_type_t; \ - __VA_ARGS__; \ - }); \ - }) - void copy_contiguous( cu::CommandEncoder& encoder, CopyType ctype, diff --git a/mlx/backend/cuda/copy/copy_contiguous.cu b/mlx/backend/cuda/copy/copy_contiguous.cu index 5f4c9ca8f..84b8f8aa8 100644 --- a/mlx/backend/cuda/copy/copy_contiguous.cu +++ b/mlx/backend/cuda/copy/copy_contiguous.cu @@ -36,19 +36,23 @@ void copy_contiguous( int64_t in_offset, int64_t out_offset) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { - using IdxT = std::conditional_t; - auto kernel = cu::copy_s; - if (ctype == CopyType::Vector) { - kernel = cu::copy_v; - } - auto [num_blocks, block_dims] = get_launch_args( - kernel, out.data_size(), out.shape(), out.strides(), LARGE); - kernel<<>>( - in.data() + in_offset, - out.data() + out_offset, - out.data_size()); + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { + using IdxT = std::conditional_t; + auto kernel = cu::copy_s; + if (ctype == CopyType::Vector) { + kernel = cu::copy_v; + } + auto [num_blocks, block_dims] = get_launch_args( + kernel, out.data_size(), out.shape(), out.strides(), LARGE); + kernel<<>>( + in.data() + in_offset, + out.data() + out_offset, + out.data_size()); + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general.cu b/mlx/backend/cuda/copy/copy_general.cu index 2dc08c60a..a6ef33896 100644 --- a/mlx/backend/cuda/copy/copy_general.cu +++ b/mlx/backend/cuda/copy/copy_general.cu @@ -56,42 +56,46 @@ void copy_general( const Strides& strides_in, const Strides& strides_out) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - size_t data_size = 1; - for (auto& s : shape) - data_size *= s; - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_nd; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + size_t 
data_size = 1; + for (auto& s : shape) + data_size *= s; + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = cu::copy_gg_nd; + auto [num_blocks, block_dims] = get_launch_args( + kernel, data_size, shape, out.strides(), large); + kernel<<>>( + in_ptr, + out_ptr, + data_size, + const_param(shape), + const_param(strides_in), + const_param(strides_out)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg; auto [num_blocks, block_dims] = get_launch_args(kernel, data_size, shape, out.strides(), large); kernel<<>>( in_ptr, out_ptr, data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg; - auto [num_blocks, block_dims] = - get_launch_args(kernel, data_size, shape, out.strides(), large); - kernel<<>>( - in_ptr, - out_ptr, - data_size, - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim); - } + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim); + } + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_dynamic.cu b/mlx/backend/cuda/copy/copy_general_dynamic.cu index 2e1cf4fba..39e50f5eb 100644 --- a/mlx/backend/cuda/copy/copy_general_dynamic.cu +++ b/mlx/backend/cuda/copy/copy_general_dynamic.cu @@ -62,41 +62,46 @@ void copy_general_dynamic( const array& dynamic_offset_in, const array& dynamic_offset_out) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_gg_dynamic_nd; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = cu::copy_gg_dynamic_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + dynamic_offset_in.data(), + dynamic_offset_out.data()); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_gg_dynamic; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); kernel<<>>( in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), + const_param(shape), + const_param(strides_in), + const_param(strides_out), + ndim, dynamic_offset_in.data(), dynamic_offset_out.data()); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_gg_dynamic; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - const_param(strides_out), - ndim, - dynamic_offset_in.data(), - dynamic_offset_out.data()); - } + } + }); }); }); }); diff --git a/mlx/backend/cuda/copy/copy_general_input.cu b/mlx/backend/cuda/copy/copy_general_input.cu index a3bb37e53..d025a5a67 100644 --- 
a/mlx/backend/cuda/copy/copy_general_input.cu +++ b/mlx/backend/cuda/copy/copy_general_input.cu @@ -51,35 +51,40 @@ void copy_general_input( const Shape& shape, const Strides& strides_in) { encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, { - const InType* in_ptr = in.data() + offset_in; - OutType* out_ptr = out.data() + offset_out; - bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; - MLX_SWITCH_BOOL(large, LARGE, { - using IdxT = std::conditional_t; - int ndim = shape.size(); - if (ndim <= 3) { - MLX_SWITCH_1_2_3(ndim, NDIM, { - auto kernel = cu::copy_g_nd; + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using InType = cuda_type_t; + using OutType = cuda_type_t; + const InType* in_ptr = in.data() + offset_in; + OutType* out_ptr = out.data() + offset_out; + bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; + MLX_SWITCH_BOOL(large, LARGE, { + using IdxT = std::conditional_t; + int ndim = shape.size(); + if (ndim <= 3) { + MLX_SWITCH_1_2_3(ndim, NDIM, { + auto kernel = cu::copy_g_nd; + auto [num_blocks, block_dims] = + get_launch_args(kernel, out, large); + kernel<<>>( + in_ptr, + out_ptr, + out.size(), + const_param(shape), + const_param(strides_in)); + }); + } else { // ndim >= 4 + auto kernel = cu::copy_g; auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); kernel<<>>( in_ptr, out_ptr, out.size(), - const_param(shape), - const_param(strides_in)); - }); - } else { // ndim >= 4 - auto kernel = cu::copy_g; - auto [num_blocks, block_dims] = get_launch_args(kernel, out, large); - kernel<<>>( - in_ptr, - out_ptr, - out.size(), - const_param(shape), - const_param(strides_in), - ndim); - } + const_param(shape), + const_param(strides_in), + ndim); + } + }); }); }); }); diff --git a/mlx/backend/cuda/layer_norm.cu b/mlx/backend/cuda/layer_norm.cu index c71795fad..27e4ddd9d 100644 --- a/mlx/backend/cuda/layer_norm.cu +++ b/mlx/backend/cuda/layer_norm.cu @@ -259,8 +259,8 @@ void LayerNorm::eval_gpu( encoder.set_input_array(b); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { + using DataType = cuda_type_t; constexpr uint32_t N_READS = 4; MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { auto kernel = cu::layer_norm; @@ -357,8 +357,8 @@ void LayerNormVJP::eval_gpu( encoder.set_output_array(gx); encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { + using DataType = cuda_type_t; constexpr int N_READS = 4; MLX_SWITCH_BOOL(has_w, HAS_W, { MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { diff --git a/mlx/backend/cuda/logsumexp.cu b/mlx/backend/cuda/logsumexp.cu index f57f82ea8..3615c8291 100644 --- a/mlx/backend/cuda/logsumexp.cu +++ b/mlx/backend/cuda/logsumexp.cu @@ -144,8 +144,8 @@ void LogSumExp::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, { - using DataType = cuda_type_t; + 
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { + using DataType = cuda_type_t; constexpr int N_READS = 4; MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { auto kernel = cu::logsumexp; diff --git a/mlx/backend/cuda/primitives.cu b/mlx/backend/cuda/primitives.cu index e32befc9c..715e5a232 100644 --- a/mlx/backend/cuda/primitives.cu +++ b/mlx/backend/cuda/primitives.cu @@ -28,7 +28,8 @@ void Arange::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); encoder.set_output_array(out); encoder.launch_kernel([&, this](cudaStream_t stream) { - MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, { + dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); using OutType = cuda_type_t; CTYPE step = static_cast(start_ + step_) - static_cast(start_); diff --git a/mlx/backend/cuda/reduce/all_reduce.cu b/mlx/backend/cuda/reduce/all_reduce.cu index 5a7c28041..51f644622 100644 --- a/mlx/backend/cuda/reduce/all_reduce.cu +++ b/mlx/backend/cuda/reduce/all_reduce.cu @@ -111,10 +111,10 @@ void all_reduce( encoder.add_temporary(intermediate); encoder.set_output_array(intermediate); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(dt, CTYPE, { + dispatch_all_types(dt, [&](auto type_tag) { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; kernel<<>>( static_cast(indata), @@ -135,10 +135,10 @@ void all_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(dt, CTYPE, { + dispatch_all_types(dt, [&](auto type_tag) { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::all_reduce; kernel<<>>( static_cast(indata), out.data(), block_step, insize); diff --git a/mlx/backend/cuda/reduce/col_reduce.cu b/mlx/backend/cuda/reduce/col_reduce.cu index 192a9b3e8..e88b09f19 100644 --- a/mlx/backend/cuda/reduce/col_reduce.cu +++ b/mlx/backend/cuda/reduce/col_reduce.cu @@ -215,11 +215,12 @@ void col_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { using T = cuda_type_t; - using U = cu::ReduceResult::type; + using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. 
(sigh) T* indata = const_cast(in.data()); diff --git a/mlx/backend/cuda/reduce/init_reduce.cu b/mlx/backend/cuda/reduce/init_reduce.cu index 50fe109c4..0f9e7a993 100644 --- a/mlx/backend/cuda/reduce/init_reduce.cu +++ b/mlx/backend/cuda/reduce/init_reduce.cu @@ -33,10 +33,10 @@ void init_reduce( encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + dispatch_all_types(in.dtype(), [&](auto type_tag) { MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { - using T = cuda_type_t; - using U = cu::ReduceResult::type; + using T = cuda_type_t; + using U = typename cu::ReduceResult::type; auto kernel = cu::init_reduce; dim3 grid = get_2d_grid_dims(out.shape(), out.strides()); dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1); diff --git a/mlx/backend/cuda/reduce/row_reduce.cu b/mlx/backend/cuda/reduce/row_reduce.cu index 6a8a35311..88cc2a1fc 100644 --- a/mlx/backend/cuda/reduce/row_reduce.cu +++ b/mlx/backend/cuda/reduce/row_reduce.cu @@ -246,10 +246,11 @@ void row_reduce_simple( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { using T = cuda_type_t; - using U = cu::ReduceResult::type; + using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. (sigh) T* indata = const_cast(in.data()); @@ -293,10 +294,11 @@ void row_reduce_looped( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { using T = cuda_type_t; - using U = cu::ReduceResult::type; + using U = typename cu::ReduceResult::type; // Cub doesn't like const pointers for vectorized loads. 
(sigh) T* indata = const_cast(in.data()); diff --git a/mlx/backend/cuda/rms_norm.cu b/mlx/backend/cuda/rms_norm.cu index 3c521b90d..bc4c86666 100644 --- a/mlx/backend/cuda/rms_norm.cu +++ b/mlx/backend/cuda/rms_norm.cu @@ -225,8 +225,8 @@ void RMSNorm::eval_gpu( encoder.set_input_array(w); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { + using DataType = cuda_type_t; constexpr uint32_t N_READS = 4; MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { auto kernel = cu::rms_norm; @@ -311,8 +311,8 @@ void RMSNormVJP::eval_gpu( encoder.set_output_array(gx); encoder.set_output_array(gw_temp); encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { + using DataType = cuda_type_t; constexpr int N_READS = 4; MLX_SWITCH_BOOL(has_w, HAS_W, { MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 1d8307811..a1081622b 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -310,8 +310,8 @@ void RoPE::eval_gpu( encoder.set_input_array(offset); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + using DataType = cuda_type_t; MLX_SWITCH_BOOL(traditional_, TRADITIONAL, { MLX_SWITCH_BOOL(forward_, FORWARD, { if (single && !with_freqs) { diff --git a/mlx/backend/cuda/softmax.cu b/mlx/backend/cuda/softmax.cu index 652e6da19..e7fb14b11 100644 --- a/mlx/backend/cuda/softmax.cu +++ b/mlx/backend/cuda/softmax.cu @@ -142,8 +142,8 @@ void Softmax::eval_gpu(const std::vector& inputs, array& out) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, { - using DataType = cuda_type_t; + dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { + using DataType = cuda_type_t; constexpr int N_READS = 4; MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { auto kernel = cu::softmax; diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 5cbffc0f4..2c5599bed 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -76,6 +76,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... 
args) { temp.data(), size, args...)); } +struct OffsetTransform { + int nsort; + + int __device__ operator()(int i) { + return i * nsort; + } +}; + void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { array out = out_; auto& encoder = cu::get_command_encoder(s); @@ -106,12 +114,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) { encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, { + dispatch_all_types(in.dtype(), [&](auto type_tag) { + using CTYPE = MLX_GET_TYPE(type_tag); if constexpr (!std::is_same_v) { using Type = cuda_type_t; auto offsets = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - [nsort] __device__(int i) { return i * nsort; }); + thrust::make_counting_iterator(0), OffsetTransform{nsort}); if (argsort) { // Indices in the sorted dimension. array indices( diff --git a/mlx/backend/cuda/ternary.cu b/mlx/backend/cuda/ternary.cu index e33af3c80..080e789f6 100644 --- a/mlx/backend/cuda/ternary.cu +++ b/mlx/backend/cuda/ternary.cu @@ -92,8 +92,8 @@ void ternary_op_gpu_inplace( encoder.set_input_array(c); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, { - using DType = cuda_type_t; + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using DType = cuda_type_t; auto topt = get_ternary_op_type(a, b, c); if (topt == TernaryOpType::General) { diff --git a/mlx/backend/cuda/unary.cu b/mlx/backend/cuda/unary.cu index e45144eda..4f9bac29f 100644 --- a/mlx/backend/cuda/unary.cu +++ b/mlx/backend/cuda/unary.cu @@ -79,8 +79,10 @@ void unary_op_gpu_inplace( encoder.set_input_array(in); encoder.set_output_array(out); encoder.launch_kernel([&](cudaStream_t stream) { - MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { - MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { + dispatch_all_types(in.dtype(), [&](auto in_type_tag) { + dispatch_all_types(out.dtype(), [&](auto out_type_tag) { + using CTYPE_IN = MLX_GET_TYPE(in_type_tag); + using CTYPE_OUT = MLX_GET_TYPE(out_type_tag); if constexpr (cu::supports_unary_op()) { using InType = cuda_type_t; using OutType = cuda_type_t; diff --git a/mlx/backend/cuda/utils.cpp b/mlx/backend/cuda/utils.cpp index 4a3d8be30..34f22ba7f 100644 --- a/mlx/backend/cuda/utils.cpp +++ b/mlx/backend/cuda/utils.cpp @@ -34,13 +34,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) { if (dtype == complex64) { return "cuComplex"; } -#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ - if (dtype == DTYPE) { \ - return #CPP_TYPE; \ - } - MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) -#undef SPECIALIZE_DtypeToString - return nullptr; + return dtype_to_string(dtype); } } // namespace mlx::core diff --git a/mlx/dtype_utils.cpp b/mlx/dtype_utils.cpp index a4448536d..270949ad6 100644 --- a/mlx/dtype_utils.cpp +++ b/mlx/dtype_utils.cpp @@ -5,16 +5,37 @@ namespace mlx::core { const char* dtype_to_string(Dtype arg) { - if (arg == bool_) { - return "bool"; + switch (arg) { + case bool_: + return "bool"; + case int8: + return "int8"; + case int16: + return "int16"; + case int32: + return "int32"; + case int64: + return "int64"; + case uint8: + return "uint8"; + case uint16: + return "uint16"; + case uint32: + return "uint32"; + case uint64: + return "uint64"; + case float16: + return "float16"; + case bfloat16: + return "bfloat16"; + case float32: + return "float32"; + case float64: + return "float64"; + case complex64: + return 
"complex64"; } -#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \ - if (DTYPE == arg) { \ - return #DTYPE; \ - } - MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString) -#undef SPECIALIZE_DtypeToString - return "(unknown)"; + return "unknown"; } } // namespace mlx::core diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h index 55de890f2..27fe432f6 100644 --- a/mlx/dtype_utils.h +++ b/mlx/dtype_utils.h @@ -1,207 +1,106 @@ // Copyright © 2025 Apple Inc. -// Copyright © Meta Platforms, Inc. and affiliates. -// -// This source code is licensed under the BSD-style license found in -// https://github.com/pytorch/executorch/blob/main/LICENSE -// -// Forked from -// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h #pragma once -#include "mlx/dtype.h" +#include -#include +#include "mlx/dtype.h" +#include "mlx/utils.h" namespace mlx::core { // Return string representation of dtype. const char* dtype_to_string(Dtype arg); -// Macros that iterate across different subsets of Dtypes. -// -// For all of these macros, the final `_` parameter is the name of another macro -// that takes two parameters: the name of a C type, and the name of the -// corresponding Dtype enumerator. -// -// Note that these macros should use fully-qualified namespaces (starting with -// `::`) to ensure that they can be called safely in any arbitrary namespace. -#define MLX_FORALL_INT_TYPES(_) \ - _(uint8_t, uint8) \ - _(uint16_t, uint16) \ - _(uint32_t, uint32) \ - _(uint64_t, uint64) \ - _(int8_t, int8) \ - _(int16_t, int16) \ - _(int32_t, int32) \ - _(int64_t, int64) +#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \ + case DTYPE: \ + f(type_identity{}); \ + break -#define MLX_FORALL_FLOAT_TYPES(_) \ - _(float16_t, float16) \ - _(float, float32) \ - _(double, float64) \ - _(bfloat16_t, bfloat16) +#define MLX_INTERNAL_DTYPE_SWITCH_INTS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t) -// Calls the provided macro on every Dtype, providing the C type and the -// Dtype name to each call. -// -// @param _ A macro that takes two parameters: the name of a C type, and the -// name of the corresponding Dtype enumerator. -#define MLX_FORALL_DTYPES(_) \ - MLX_FORALL_INT_TYPES(_) \ - MLX_FORALL_FLOAT_TYPES(_) \ - _(bool, bool_) \ - _(complex64_t, complex64) +#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double) -// Maps Dtypes to C++ types. -template -struct DtypeToCppType; - -#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \ - template <> \ - struct DtypeToCppType { \ - using type = CPP_TYPE; \ - }; - -MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType) - -#undef SPECIALIZE_DtypeToCppType - -// Maps C++ types to Dtypes. +// This already exists in C++20 but in C++20 we can also just use templated +// lambdas which will make this so much nicer. 
template -struct CppTypeToDtype; +struct type_identity { + using type = T; +}; -#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \ - template <> \ - struct CppTypeToDtype \ - : std::integral_constant {}; +#define MLX_GET_TYPE(x) typename decltype(x)::type +#define MLX_GET_VALUE(x) decltype(x)::value -MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype) - -#undef SPECIALIZE_CppTypeToDtype - -// Helper macros for switch case macros (see below) -// -// These macros are not meant to be used directly. They provide an easy way to -// generate a switch statement that can handle subsets of Dtypes supported. - -#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \ - case enum_type: { \ - using CTYPE_ALIAS = ::mlx::core::DtypeToCppType::type; \ - __VA_ARGS__; \ - break; \ +template +void dispatch_all_types(Dtype dt, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t); } +} -#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \ - switch (TYPE) { \ - __VA_ARGS__ \ - default: \ - throw std::invalid_argument(fmt::format( \ - "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \ +template +void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + default: + std::ostringstream msg; + msg << tag << " Only integer types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); } +} -#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__) +template +void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only float types supported but " << dt << " was provided"; + throw std::invalid_argument(msg.str()); + } +} -#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__) +template +void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only integer and float types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} -#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) 
\ - MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE( \ - ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__) - -#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \ - MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__) - -// Switch case macros -// -// These macros provide an easy way to generate switch statements that apply a -// common lambda function to subsets of Dtypes supported by MLX. -// The lambda function can type specialize to the ctype associated with the -// Dtype being handled through an alias passed as the CTYPE_ALIAS argument. -// -// Arguments: -// - ADDITIONAL: Additional Dtype case to add -// - TYPE: The Dtype to handle through the switch statement -// - NAME: A name for this operation which will be used in error messages -// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype. -// - ...: A statement to be applied to each Dtype case -// -// An example usage is: -// -// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, { -// output.data[0] = input.data[0]; -// }); -// -// Note that these can be nested as well: -// -// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, { -// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, { -// output.data[0] = input.data[0]; -// }); -// }); -// -// These macros are adapted from Dispatch.h in the ATen library. The primary -// difference is that the CTYPE_ALIAS argument is exposed to users, which is -// used to alias the ctype associated with the Dtype that is being handled. - -#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \ - switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) } - -#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \ - MLX_INTERNAL_SWITCH_CHECKED( \ - TYPE, \ - NAME, \ - MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)) - -#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) 
-  MLX_INTERNAL_SWITCH_CHECKED(                                    \
-      TYPE,                                                       \
-      NAME,                                                       \
-      MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
+template <typename F>
+void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only real numbers supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
 
 } // namespace mlx::core
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 61b9da3a2..e53a7a97f 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -253,7 +253,9 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
 
 std::ostream& operator<<(std::ostream& os, array a) {
   a.eval();
-  MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
+  dispatch_all_types(a.dtype(), [&](auto type_tag) {
+    print_array<MLX_GET_TYPE(type_tag)>(os, a);
+  });
   return os;
 }
 
@@ -321,8 +323,9 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
 }
 
 iinfo::iinfo(Dtype dtype) : dtype(dtype) {
-  MLX_SWITCH_INT_TYPES_CHECKED(
-      dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
+  dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) {
+    set_iinfo_limits<MLX_GET_TYPE(type_tag)>(min, max);
+  });
 }
 
 } // namespace mlx::core
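
Note (not part of the patch): every dispatch_* helper introduced above follows the same tag-dispatch pattern, so a small self-contained sketch may help reviewers who have not seen it before. The reduced Dtype enum, the simplified dispatcher body, and main() below are illustrative stand-ins, not MLX code; only type_identity, the MLX_GET_TYPE idea, and the switch-per-dtype shape mirror what the patch adds to mlx/dtype_utils.h, and the real helpers also cover float16_t, bfloat16_t, and complex64_t.

// Self-contained sketch of the tag-dispatch pattern introduced by this patch.
// Assumes nothing from MLX beyond the pattern itself; compiles as plain C++17.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string_view>

// A value that carries a type, so a generic lambda can recover it later.
template <typename T>
struct type_identity {
  using type = T;
};

// Same idea as the MLX_GET_TYPE macro added in mlx/dtype_utils.h.
#define MLX_GET_TYPE(x) typename decltype(x)::type

// Reduced stand-in for mlx::core::Dtype, for illustration only.
enum class Dtype { int32, int64, float32, float64 };

template <typename F>
void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
  switch (dt) {
    case Dtype::int32:
      f(type_identity<int32_t>{});
      break;
    case Dtype::int64:
      f(type_identity<int64_t>{});
      break;
    case Dtype::float32:
      f(type_identity<float>{});
      break;
    case Dtype::float64:
      f(type_identity<double>{});
      break;
    default: {
      std::ostringstream msg;
      msg << tag << ": unsupported dtype";
      throw std::invalid_argument(msg.str());
    }
  }
}

int main() {
  // Caller side: what used to be MLX_SWITCH_*(dtype, CTYPE, { ... }) becomes a
  // generic lambda; the concrete C++ type is recovered from the tag argument.
  dispatch_real_types(Dtype::float32, "example", [&](auto type_tag) {
    using T = MLX_GET_TYPE(type_tag);
    std::cout << "dispatched with sizeof(T) = " << sizeof(T) << "\n";
  });
  return 0;
}

Once the project can assume C++20, the type_identity tag could be replaced by a templated lambda ([]<typename T>() { ... }), which is what the "in C++20 we can also just use templated lambdas" comment in dtype_utils.h alludes to.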