MLX_SWITCH macros to templates (#2320)

This commit is contained in:
Angelos Katharopoulos
2025-07-01 01:33:44 -07:00
committed by GitHub
parent 33bf1a244b
commit 3d5e17e507
27 changed files with 693 additions and 692 deletions

View File

@@ -56,42 +56,48 @@ void copy_general(
const Strides& strides_in,
const Strides& strides_out) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
MLX_SWITCH_BOOL(large, LARGE, {
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param<NDIM>(shape),
const_param<NDIM>(strides_in),
const_param<NDIM>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
dispatch_bool(
in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
[&](auto large) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<large(), int64_t, int32_t>;
const InType* in_ptr = in.data<InType>() + offset_in;
OutType* out_ptr = out.data<OutType>() + offset_out;
int ndim = shape.size();
size_t data_size = 1;
for (auto& s : shape)
data_size *= s;
if (ndim <= 3) {
dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto kernel =
cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param<ndim_constant()>(shape),
const_param<ndim_constant()>(strides_in),
const_param<ndim_constant()>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
});
});
});
});