Start changing MLX_SWITCH to templates

This commit is contained in:
Angelos Katharopoulos
2025-06-29 02:29:23 -07:00
parent 772f471ff2
commit 45c43dd24a
25 changed files with 296 additions and 359 deletions

View File

@@ -152,26 +152,19 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
using InType = cuda_type_t<CTYPE>;
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
dim3 block_dims{BLOCK_DIM, 1, 1};
auto kernel = &cu::arg_reduce_general<
InType,
cu::ArgMax<InType>,
BLOCK_DIM,
N_READS>;
auto kernel =
cu::arg_reduce_general<T, cu::ArgMax<T>, BLOCK_DIM, N_READS>;
if (reduce_type_ == ArgReduce::ArgMin) {
kernel = &cu::arg_reduce_general<
InType,
cu::ArgMin<InType>,
BLOCK_DIM,
N_READS>;
kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, BLOCK_DIM, N_READS>;
}
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>(),
in.data<T>(),
out.data<uint32_t>(),
out.size(),
const_param(shape),

View File

@@ -140,8 +140,10 @@ void binary_op_gpu_inplace(
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;

View File

@@ -138,8 +138,10 @@ void binary_op_gpu_inplace(
encoder.set_output_array(out_a);
encoder.set_output_array(out_b);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;

View File

@@ -10,15 +10,6 @@
namespace mlx::core {
#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
using InType = cuda_type_t<CTYPE_IN>; \
using OutType = cuda_type_t<CTYPE_OUT>; \
__VA_ARGS__; \
}); \
})
void copy_contiguous(
cu::CommandEncoder& encoder,
CopyType ctype,

View File

@@ -36,19 +36,23 @@ void copy_contiguous(
int64_t in_offset,
int64_t out_offset) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::copy_s<InType, OutType, IdxT>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
auto kernel = cu::copy_s<InType, OutType, IdxT>;
if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>;
}
auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset,
out.data<OutType>() + out_offset,
out.data_size());
});
});
});
});

View File

@@ -56,42 +56,46 @@ void copy_general(
const Strides& strides_in,
const Strides& strides_out) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
});
});
});
});

View File

@@ -62,41 +62,46 @@ void copy_general_dynamic(
const array& dynamic_offset_in,
const array& dynamic_offset_out) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
}
});
});
});
});

View File

@@ -51,35 +51,40 @@ void copy_general_input(
const Shape& shape,
const Strides& strides_in) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param<NDIM>(shape),
const_param<NDIM>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
ndim);
}
const_param(shape),
const_param(strides_in),
ndim);
}
});
});
});
});

View File

@@ -259,8 +259,8 @@ void LayerNorm::eval_gpu(
encoder.set_input_array(b);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
@@ -357,8 +357,8 @@ void LayerNormVJP::eval_gpu(
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {

View File

@@ -144,8 +144,8 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;

View File

@@ -28,7 +28,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s);
encoder.set_output_array(out);
encoder.launch_kernel([&, this](cudaStream_t stream) {
MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
using OutType = cuda_type_t<CTYPE>;
CTYPE step =
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);

View File

@@ -111,10 +111,10 @@ void all_reduce(
encoder.add_temporary(intermediate);
encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
dispatch_all_types(dt, [&](auto type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
static_cast<T*>(indata),
@@ -135,10 +135,10 @@ void all_reduce(
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
dispatch_all_types(dt, [&](auto type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
static_cast<T*>(indata), out.data<U>(), block_step, insize);

View File

@@ -215,11 +215,12 @@ void col_reduce_looped(
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());

View File

@@ -33,10 +33,10 @@ void init_reduce(
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::init_reduce<T, U, OP>;
dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);

View File

@@ -246,10 +246,11 @@ void row_reduce_simple(
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());
@@ -293,10 +294,11 @@ void row_reduce_looped(
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh)
T* indata = const_cast<T*>(in.data<T>());

View File

@@ -225,8 +225,8 @@ void RMSNorm::eval_gpu(
encoder.set_input_array(w);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
@@ -311,8 +311,8 @@ void RMSNormVJP::eval_gpu(
encoder.set_output_array(gx);
encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, {
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {

View File

@@ -310,8 +310,8 @@ void RoPE::eval_gpu(
encoder.set_input_array(offset);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
MLX_SWITCH_BOOL(forward_, FORWARD, {
if (single && !with_freqs) {

View File

@@ -142,8 +142,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
using DataType = cuda_type_t<CTYPE>;
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;

View File

@@ -76,6 +76,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
temp.data<void>(), size, args...));
}
struct OffsetTransform {
int nsort;
int __device__ operator()(int i) {
return i * nsort;
}
};
void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
array out = out_;
auto& encoder = cu::get_command_encoder(s);
@@ -106,12 +114,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag);
if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
using Type = cuda_type_t<CTYPE>;
auto offsets = thrust::make_transform_iterator(
thrust::make_counting_iterator(0),
[nsort] __device__(int i) { return i * nsort; });
thrust::make_counting_iterator(0), OffsetTransform{nsort});
if (argsort) {
// Indices in the sorted dimension.
array indices(

View File

@@ -92,8 +92,8 @@ void ternary_op_gpu_inplace(
encoder.set_input_array(c);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
using DType = cuda_type_t<CTYPE>;
dispatch_all_types(out.dtype(), [&](auto type_tag) {
using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) {

View File

@@ -79,8 +79,10 @@ void unary_op_gpu_inplace(
encoder.set_input_array(in);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
using InType = cuda_type_t<CTYPE_IN>;
using OutType = cuda_type_t<CTYPE_OUT>;

View File

@@ -34,13 +34,7 @@ const char* dtype_to_cuda_type(const Dtype& dtype) {
if (dtype == complex64) {
return "cuComplex";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (dtype == DTYPE) { \
return #CPP_TYPE; \
}
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
#undef SPECIALIZE_DtypeToString
return nullptr;
return dtype_to_string(dtype);
}
} // namespace mlx::core

View File

@@ -5,16 +5,37 @@
namespace mlx::core {
const char* dtype_to_string(Dtype arg) {
if (arg == bool_) {
return "bool";
switch (arg) {
case bool_:
return "bool";
case int8:
return "int8";
case int16:
return "int16";
case int32:
return "int32";
case int64:
return "int64";
case uint8:
return "uint8";
case uint16:
return "uint16";
case uint32:
return "uint32";
case uint64:
return "uint64";
case float16:
return "float16";
case bfloat16:
return "bfloat16";
case float32:
return "float32";
case float64:
return "float64";
case complex64:
return "complex64";
}
#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
if (DTYPE == arg) { \
return #DTYPE; \
}
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
#undef SPECIALIZE_DtypeToString
return "(unknown)";
return "unknown";
}
} // namespace mlx::core

View File

@@ -1,207 +1,106 @@
// Copyright © 2025 Apple Inc.
// Copyright © Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in
// https://github.com/pytorch/executorch/blob/main/LICENSE
//
// Forked from
// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h
#pragma once
#include "mlx/dtype.h"
#include <sstream>
#include <fmt/format.h>
#include "mlx/dtype.h"
#include "mlx/utils.h"
namespace mlx::core {
// Return string representation of dtype.
const char* dtype_to_string(Dtype arg);
// Macros that iterate across different subsets of Dtypes.
//
// For all of these macros, the final `_` parameter is the name of another macro
// that takes two parameters: the name of a C type, and the name of the
// corresponding Dtype enumerator.
//
// Note that these macros should use fully-qualified namespaces (starting with
// `::`) to ensure that they can be called safely in any arbitrary namespace.
#define MLX_FORALL_INT_TYPES(_) \
_(uint8_t, uint8) \
_(uint16_t, uint16) \
_(uint32_t, uint32) \
_(uint64_t, uint64) \
_(int8_t, int8) \
_(int16_t, int16) \
_(int32_t, int32) \
_(int64_t, int64)
#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \
case DTYPE: \
f(type_identity<TYPE>{}); \
break
#define MLX_FORALL_FLOAT_TYPES(_) \
_(float16_t, float16) \
_(float, float32) \
_(double, float64) \
_(bfloat16_t, bfloat16)
#define MLX_INTERNAL_DTYPE_SWITCH_INTS() \
MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)
// Calls the provided macro on every Dtype, providing the C type and the
// Dtype name to each call.
//
// @param _ A macro that takes two parameters: the name of a C type, and the
// name of the corresponding Dtype enumerator.
#define MLX_FORALL_DTYPES(_) \
MLX_FORALL_INT_TYPES(_) \
MLX_FORALL_FLOAT_TYPES(_) \
_(bool, bool_) \
_(complex64_t, complex64)
#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \
MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double)
// Maps Dtypes to C++ types.
template <Dtype::Val N>
struct DtypeToCppType;
#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
template <> \
struct DtypeToCppType<Dtype::Val::DTYPE> { \
using type = CPP_TYPE; \
};
MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
#undef SPECIALIZE_DtypeToCppType
// Maps C++ types to Dtypes.
// This already exists in C++20 but in C++20 we can also just use templated
// lambdas which will make this so much nicer.
template <typename T>
struct CppTypeToDtype;
struct type_identity {
using type = T;
};
#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \
template <> \
struct CppTypeToDtype<CPP_TYPE> \
: std::integral_constant<Dtype::Val, Dtype::Val::DTYPE> {};
#define MLX_GET_TYPE(x) typename decltype(x)::type
#define MLX_GET_VALUE(x) decltype(x)::value
MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype)
#undef SPECIALIZE_CppTypeToDtype
// Helper macros for switch case macros (see below)
//
// These macros are not meant to be used directly. They provide an easy way to
// generate a switch statement that can handle subsets of Dtypes supported.
#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
case enum_type: { \
using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
__VA_ARGS__; \
break; \
template <typename F>
void dispatch_all_types(Dtype dt, F&& f) {
switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
MLX_INTERNAL_DTYPE_SWITCH_INTS();
MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
}
}
#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...) \
switch (TYPE) { \
__VA_ARGS__ \
default: \
throw std::invalid_argument(fmt::format( \
"Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \
template <typename F>
void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) {
switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_INTS();
default:
std::ostringstream msg;
msg << tag << " Only integer types supported but " << dt
<< " was provided";
throw std::invalid_argument(msg.str());
}
}
#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
template <typename F>
void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
default:
std::ostringstream msg;
msg << tag << " Only float types supported but " << dt << " was provided";
throw std::invalid_argument(msg.str());
}
}
#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__)
template <typename F>
void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_INTS();
MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
default:
std::ostringstream msg;
msg << tag << " Only integer and float types supported but " << dt
<< " was provided";
throw std::invalid_argument(msg.str());
}
}
#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE( \
::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__)
#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
// Switch case macros
//
// These macros provide an easy way to generate switch statements that apply a
// common lambda function to subsets of Dtypes supported by MLX.
// The lambda function can type specialize to the ctype associated with the
// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
//
// Arguments:
// - ADDITIONAL: Additional Dtype case to add
// - TYPE: The Dtype to handle through the switch statement
// - NAME: A name for this operation which will be used in error messages
// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
// - ...: A statement to be applied to each Dtype case
//
// An example usage is:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
// output.data<CTYPE>[0] = input.data<CTYPE>[0];
// });
//
// Note that these can be nested as well:
//
// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
// MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
// output.data<CTYPE_OUT>[0] = input.data<CTYPE_IN>[0];
// });
// });
//
// These macros are adapted from Dispatch.h in the ATen library. The primary
// difference is that the CTYPE_ALIAS argument is exposed to users, which is
// used to alias the ctype associated with the Dtype that is being handled.
#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
MLX_INTERNAL_SWITCH_CHECKED( \
TYPE, \
NAME, \
MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
template <typename F>
void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
switch (dt) {
MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
MLX_INTERNAL_DTYPE_SWITCH_INTS();
MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
default:
std::ostringstream msg;
msg << tag << " Only real numbers supported but " << dt
<< " was provided";
throw std::invalid_argument(msg.str());
}
}
} // namespace mlx::core

View File

@@ -253,7 +253,9 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
std::ostream& operator<<(std::ostream& os, array a) {
a.eval();
MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
dispatch_all_types(a.dtype(), [&](auto type_tag) {
print_array<MLX_GET_TYPE(type_tag)>(os, a);
});
return os;
}
@@ -321,8 +323,9 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
}
iinfo::iinfo(Dtype dtype) : dtype(dtype) {
MLX_SWITCH_INT_TYPES_CHECKED(
dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) {
set_iinfo_limits<MLX_GET_TYPE(type_tag)>(min, max);
});
}
} // namespace mlx::core