Move all type switching to templates

This commit is contained in:
Angelos Katharopoulos
2025-06-29 04:04:15 -07:00
parent 45c43dd24a
commit ef813b6d13
19 changed files with 474 additions and 431 deletions

View File

@@ -155,25 +155,33 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) { dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4; constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dispatch_block_dim(
dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides()); cuda::ceil_div(axis_size, N_READS), [&](auto block_dim_constant) {
dim3 block_dims{BLOCK_DIM, 1, 1}; dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
auto kernel = dim3 block_dims{block_dim_constant(), 1, 1};
cu::arg_reduce_general<T, cu::ArgMax<T>, BLOCK_DIM, N_READS>; auto kernel = cu::arg_reduce_general<
if (reduce_type_ == ArgReduce::ArgMin) { T,
kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, BLOCK_DIM, N_READS>; cu::ArgMax<T>,
} block_dim_constant(),
kernel<<<num_blocks, block_dims, 0, stream>>>( N_READS>;
in.data<T>(), if (reduce_type_ == ArgReduce::ArgMin) {
out.data<uint32_t>(), kernel = cu::arg_reduce_general<
out.size(), T,
const_param(shape), cu::ArgMin<T>,
const_param(in_strides), block_dim_constant(),
const_param(out_strides), N_READS>;
ndim, }
axis_stride, kernel<<<num_blocks, block_dims, 0, stream>>>(
axis_size); in.data<T>(),
}); out.data<uint32_t>(),
out.size(),
const_param(shape),
const_param(in_strides),
const_param(out_strides),
ndim,
axis_stride,
axis_size);
});
}); });
}); });
} }

View File

@@ -149,47 +149,55 @@ void binary_op_gpu_inplace(
using OutType = cuda_type_t<CTYPE_OUT>; using OutType = cuda_type_t<CTYPE_OUT>;
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) { if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out); dispatch_bool(
auto& a_strides = strides[0]; a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
auto& b_strides = strides[1]; out.data_size() > INT32_MAX,
bool large = a.data_size() > INT32_MAX || [&](auto large) {
b.data_size() > INT32_MAX || out.data_size() > INT32_MAX; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
MLX_SWITCH_BOOL(large, LARGE, { Shape shape;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; std::vector<Strides> strides;
int ndim = shape.size(); std::tie(shape, strides) =
if (ndim <= 3) { collapse_contiguous_dims(a, b, out);
MLX_SWITCH_1_2_3(ndim, NDIM, { auto& a_strides = strides[0];
auto kernel = auto& b_strides = strides[1];
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>; int ndim = shape.size();
auto [num_blocks, block_dims] = if (ndim <= 3) {
get_launch_args(kernel, out, large); dispatch_1_2_3(ndim, [&](auto dims_constant) {
kernel<<<num_blocks, block_dims, 0, stream>>>( auto kernel = cu::binary_g_nd<
a.data<InType>(), Op,
b.data<InType>(), InType,
out.data<OutType>(), OutType,
out.size(), IdxT,
const_param<NDIM>(shape), dims_constant()>;
const_param<NDIM>(a_strides), auto [num_blocks, block_dims] =
const_param<NDIM>(b_strides)); get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
}); });
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out.data<OutType>(),
out.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else { } else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>; auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>; kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
@@ -199,7 +207,7 @@ void binary_op_gpu_inplace(
kernel = cu::binary_vv<Op, InType, OutType, IdxT>; kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel, out.data_size(), out.shape(), out.strides(), large());
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(), a.data<InType>(),
b.data<InType>(), b.data<InType>(),

View File

@@ -148,49 +148,54 @@ void binary_op_gpu_inplace(
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
if (bopt == BinaryOpType::General) { if (bopt == BinaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, out_a); dispatch_bool(
auto& a_strides = strides[0]; a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
auto& b_strides = strides[1]; out_a.data_size() > INT32_MAX,
bool large = a.data_size() > INT32_MAX || [&](auto large) {
b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX; using IdxT = std::conditional_t<large(), int64_t, int32_t>;
MLX_SWITCH_BOOL(large, LARGE, { auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; auto& a_strides = strides[0];
int ndim = shape.size(); auto& b_strides = strides[1];
if (ndim <= 3) { int ndim = shape.size();
MLX_SWITCH_1_2_3(ndim, NDIM, { if (ndim <= 3) {
auto kernel = dispatch_1_2_3(ndim, [&](auto dims_constant) {
cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>; auto kernel = cu::binary_g_nd<
auto [num_blocks, block_dims] = Op,
get_launch_args(kernel, out_a, large); InType,
kernel<<<num_blocks, block_dims, 0, stream>>>( OutType,
a.data<InType>(), IdxT,
b.data<InType>(), dims_constant()>;
out_a.data<OutType>(), auto [num_blocks, block_dims] =
out_b.data<OutType>(), get_launch_args(kernel, out_a, large());
out_a.size(), kernel<<<num_blocks, block_dims, 0, stream>>>(
const_param<NDIM>(shape), a.data<InType>(),
const_param<NDIM>(a_strides), b.data<InType>(),
const_param<NDIM>(b_strides)); out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides));
});
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
}); });
} else {
auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(),
b.data<InType>(),
out_a.data<OutType>(),
out_b.data<OutType>(),
out_a.size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
ndim);
}
});
} else { } else {
MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, { dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>; auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
if (bopt == BinaryOpType::ScalarVector) { if (bopt == BinaryOpType::ScalarVector) {
kernel = cu::binary_sv<Op, InType, OutType, IdxT>; kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
@@ -204,7 +209,7 @@ void binary_op_gpu_inplace(
out_a.data_size(), out_a.data_size(),
out_a.shape(), out_a.shape(),
out_a.strides(), out_a.strides(),
LARGE); large());
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<InType>(), a.data<InType>(),
b.data<InType>(), b.data<InType>(),

View File

@@ -38,16 +38,16 @@ void copy_contiguous(
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::copy_s<InType, OutType, IdxT>; auto kernel = cu::copy_s<InType, OutType, IdxT>;
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
kernel = cu::copy_v<InType, OutType, IdxT>; kernel = cu::copy_v<InType, OutType, IdxT>;
} }
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel, out.data_size(), out.shape(), out.strides(), large());
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(
in.data<InType>() + in_offset, in.data<InType>() + in_offset,
out.data<OutType>() + out_offset, out.data<OutType>() + out_offset,

View File

@@ -58,44 +58,46 @@ void copy_general(
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; dispatch_bool(
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
const InType* in_ptr = in.data<InType>() + offset_in; [&](auto large) {
OutType* out_ptr = out.data<OutType>() + offset_out; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
MLX_SWITCH_BOOL(large, LARGE, { using IdxT = std::conditional_t<large(), int64_t, int32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; const InType* in_ptr = in.data<InType>() + offset_in;
int ndim = shape.size(); OutType* out_ptr = out.data<OutType>() + offset_out;
size_t data_size = 1; int ndim = shape.size();
for (auto& s : shape) size_t data_size = 1;
data_size *= s; for (auto& s : shape)
if (ndim <= 3) { data_size *= s;
MLX_SWITCH_1_2_3(ndim, NDIM, { if (ndim <= 3) {
auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>; dispatch_1_2_3(ndim, [&](auto ndim_constant) {
auto [num_blocks, block_dims] = get_launch_args( auto kernel =
kernel, data_size, shape, out.strides(), large); cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
kernel<<<num_blocks, block_dims, 0, stream>>>( auto [num_blocks, block_dims] = get_launch_args(
in_ptr, kernel, data_size, shape, out.strides(), large());
out_ptr, kernel<<<num_blocks, block_dims, 0, stream>>>(
data_size, in_ptr,
const_param<NDIM>(shape), out_ptr,
const_param<NDIM>(strides_in), data_size,
const_param<NDIM>(strides_out)); const_param<ndim_constant()>(shape),
const_param<ndim_constant()>(strides_in),
const_param<ndim_constant()>(strides_out));
});
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(
kernel, data_size, shape, out.strides(), large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
}); });
} else { // ndim >= 4
auto kernel = cu::copy_gg<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, data_size, shape, out.strides(), large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
data_size,
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim);
}
});
}); });
}); });
}); });

View File

@@ -64,44 +64,50 @@ void copy_general_dynamic(
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; dispatch_bool(
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
const InType* in_ptr = in.data<InType>() + offset_in; [&](auto large) {
OutType* out_ptr = out.data<OutType>() + offset_out; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
MLX_SWITCH_BOOL(large, LARGE, { using IdxT = std::conditional_t<large(), int64_t, int32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; const InType* in_ptr = in.data<InType>() + offset_in;
int ndim = shape.size(); OutType* out_ptr = out.data<OutType>() + offset_out;
if (ndim <= 3) { int ndim = shape.size();
MLX_SWITCH_1_2_3(ndim, NDIM, { if (ndim <= 3) {
auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>; dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = auto kernel = cu::copy_gg_dynamic_nd<
get_launch_args(kernel, out, large); InType,
kernel<<<num_blocks, block_dims, 0, stream>>>( OutType,
in_ptr, IdxT,
out_ptr, dims_constant()>;
out.size(), auto [num_blocks, block_dims] =
const_param<NDIM>(shape), get_launch_args(kernel, out, large());
const_param<NDIM>(strides_in), kernel<<<num_blocks, block_dims, 0, stream>>>(
const_param<NDIM>(strides_out), in_ptr,
dynamic_offset_in.data<int64_t>(), out_ptr,
dynamic_offset_out.data<int64_t>()); out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in),
const_param<dims_constant()>(strides_out),
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
});
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
}); });
} else { // ndim >= 4
auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
const_param(strides_out),
ndim,
dynamic_offset_in.data<int64_t>(),
dynamic_offset_out.data<int64_t>());
}
});
}); });
}); });
}); });

View File

@@ -53,38 +53,41 @@ void copy_general_input(
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto in_type_tag) { dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
dispatch_all_types(out.dtype(), [&](auto out_type_tag) { dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>; dispatch_bool(
using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>; in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
const InType* in_ptr = in.data<InType>() + offset_in; [&](auto large) {
OutType* out_ptr = out.data<OutType>() + offset_out; using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX; using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
MLX_SWITCH_BOOL(large, LARGE, { using IdxT = std::conditional_t<large(), int64_t, int32_t>;
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; const InType* in_ptr = in.data<InType>() + offset_in;
int ndim = shape.size(); OutType* out_ptr = out.data<OutType>() + offset_out;
if (ndim <= 3) { int ndim = shape.size();
MLX_SWITCH_1_2_3(ndim, NDIM, { if (ndim <= 3) {
auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>; dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = auto kernel =
get_launch_args(kernel, out, large); cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
kernel<<<num_blocks, block_dims, 0, stream>>>( auto [num_blocks, block_dims] =
in_ptr, get_launch_args(kernel, out, large());
out_ptr, kernel<<<num_blocks, block_dims, 0, stream>>>(
out.size(), in_ptr,
const_param<NDIM>(shape), out_ptr,
const_param<NDIM>(strides_in)); out.size(),
const_param<dims_constant()>(shape),
const_param<dims_constant()>(strides_in));
});
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
ndim);
}
}); });
} else { // ndim >= 4
auto kernel = cu::copy_g<InType, OutType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
in_ptr,
out_ptr,
out.size(),
const_param(shape),
const_param(strides_in),
ndim);
}
});
}); });
}); });
}); });

View File

@@ -6,6 +6,8 @@
#pragma once #pragma once
#include <type_traits>
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/cuda/device/utils.cuh" #include "mlx/backend/cuda/device/utils.cuh"
@@ -17,60 +19,46 @@
namespace mlx::core { namespace mlx::core {
// Convert a number between 1~3 to constexpr. template <typename F>
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \ void dispatch_1_2_3(int n, F&& f) {
switch (N) { \ switch (n) {
case 1: { \ case 1:
constexpr int NDIM = 1; \ f(std::integral_constant<int, 1>{});
__VA_ARGS__; \ break;
break; \ case 2:
} \ f(std::integral_constant<int, 2>{});
case 2: { \ break;
constexpr int NDIM = 2; \ case 3:
__VA_ARGS__; \ f(std::integral_constant<int, 3>{});
break; \ break;
} \
case 3: { \
constexpr int NDIM = 3; \
__VA_ARGS__; \
break; \
} \
} }
}
// Like MLX_SWITCH_ALL_TYPES but for booleans. template <typename F>
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \ void dispatch_bool(bool v, F&& f) {
if (BOOL) { \ if (v) {
constexpr bool BOOL_ALIAS = true; \ f(std::true_type{});
__VA_ARGS__; \ } else {
} else { \ f(std::false_type{});
constexpr bool BOOL_ALIAS = false; \
__VA_ARGS__; \
} }
}
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2. template <typename F>
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \ void dispatch_block_dim(int threads, F&& f) {
{ \ if (threads <= WARP_SIZE) {
uint32_t _num_threads = NUM_THREADS; \ f(std::integral_constant<int, WARP_SIZE>{});
if (_num_threads <= WARP_SIZE) { \ } else if (threads <= WARP_SIZE * 2) {
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \ f(std::integral_constant<int, WARP_SIZE * 2>{});
__VA_ARGS__; \ } else if (threads <= WARP_SIZE * 4) {
} else if (_num_threads <= WARP_SIZE * 2) { \ f(std::integral_constant<int, WARP_SIZE * 4>{});
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \ } else if (threads <= WARP_SIZE * 8) {
__VA_ARGS__; \ f(std::integral_constant<int, WARP_SIZE * 8>{});
} else if (_num_threads <= WARP_SIZE * 4) { \ } else if (threads <= WARP_SIZE * 16) {
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \ f(std::integral_constant<int, WARP_SIZE * 16>{});
__VA_ARGS__; \ } else {
} else if (_num_threads <= WARP_SIZE * 8) { \ f(std::integral_constant<int, WARP_SIZE * 32>{});
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 16) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
__VA_ARGS__; \
} else { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
__VA_ARGS__; \
} \
} }
}
// Maps CPU types to CUDA types. // Maps CPU types to CUDA types.
template <typename T> template <typename T>

View File

@@ -260,20 +260,21 @@ void LayerNorm::eval_gpu(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) { dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4; constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dispatch_block_dim(
auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>; cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
x.data<DataType>(), auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
w.data<DataType>(), kernel<<<n_rows, block_dim(), 0, stream>>>(
b.data<DataType>(), x.data<DataType>(),
out.data<DataType>(), w.data<DataType>(),
eps_, b.data<DataType>(),
axis_size, out.data<DataType>(),
w_stride, eps_,
b_stride); axis_size,
}); w_stride,
b_stride);
});
}); });
}); });
} }
@@ -358,21 +359,26 @@ void LayerNormVJP::eval_gpu(
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) { dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; dispatch_bool(has_w, [&](auto has_w_constant) {
constexpr int N_READS = 4; constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, { dispatch_block_dim(
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( auto kernel = cu::layer_norm_vjp<
x.data<DataType>(), DataType,
w.data<DataType>(), has_w_constant(),
g.data<DataType>(), block_dim(),
gx.data<DataType>(), N_READS>;
gw_temp.data<DataType>(), kernel<<<n_rows, block_dim(), 0, stream>>>(
eps_, x.data<DataType>(),
axis_size, w.data<DataType>(),
w_stride); g.data<DataType>(),
}); gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
}); });
}); });
}); });

View File

@@ -145,13 +145,14 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) { dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4; constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dispatch_block_dim(
auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>; cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
in.data<DataType>(), out.data<DataType>(), axis_size); auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
}); kernel<<<n_rows, block_dim(), 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
}); });
}); });
} }

View File

@@ -112,7 +112,8 @@ void all_reduce(
encoder.set_output_array(intermediate); encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(dt, [&](auto type_tag) { dispatch_all_types(dt, [&](auto type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; auto kernel = cu::all_reduce<T, U, OP, N_READS>;
@@ -136,7 +137,8 @@ void all_reduce(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(dt, [&](auto type_tag) { dispatch_all_types(dt, [&](auto type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>; auto kernel = cu::all_reduce<T, U, OP, N_READS>;

View File

@@ -216,10 +216,10 @@ void col_reduce_looped(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag); dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh) // Cub doesn't like const pointers for vectorized loads. (sigh)
@@ -230,7 +230,8 @@ void col_reduce_looped(
constexpr int BN = 32; constexpr int BN = 32;
dim3 grid = output_grid_for_col_reduce(out, args, BN); dim3 grid = output_grid_for_col_reduce(out, args, BN);
int blocks = BM * BN / N_READS; int blocks = BM * BN / N_READS;
auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>; auto kernel =
cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args); kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
}); });
}); });

View File

@@ -34,7 +34,8 @@ void init_reduce(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
auto kernel = cu::init_reduce<T, U, OP>; auto kernel = cu::init_reduce<T, U, OP>;

View File

@@ -1,5 +1,7 @@
// Copyright © 2025 Apple Inc. // Copyright © 2025 Apple Inc.
#include <type_traits>
#include "mlx/backend/common/reduce.h" #include "mlx/backend/common/reduce.h"
#include "mlx/backend/cuda/device/cucomplex_math.cuh" #include "mlx/backend/cuda/device/cucomplex_math.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/kernel_utils.cuh"
@@ -9,43 +11,35 @@
namespace mlx::core { namespace mlx::core {
// Dispatch dynamic ndim to constexpr. template <typename F>
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file. void dispatch_reduce_ndim(int ndim, F&& f) {
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \ if (ndim == 1) {
if (ndim == 1) { \ f(std::integral_constant<int, 1>{});
constexpr uint32_t NDIM = 1; \ } else if (ndim == 2) {
__VA_ARGS__; \ f(std::integral_constant<int, 2>{});
} else if (ndim == 2) { \ } else {
constexpr uint32_t NDIM = 2; \ f(std::integral_constant<int, 5>{});
__VA_ARGS__; \
} else { \
constexpr uint32_t NDIM = 5; \
__VA_ARGS__; \
} }
}
// Dispatch reduce ops to constexpr. template <typename F>
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \ void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {
if (REDUCE == Reduce::ReduceType::And) { \ if (reduce_type == Reduce::ReduceType::And) {
using OP = cu::And; \ f(type_identity<cu::And>{});
__VA_ARGS__; \ } else if (reduce_type == Reduce::ReduceType::Or) {
} else if (REDUCE == Reduce::ReduceType::Or) { \ f(type_identity<cu::Or>{});
using OP = cu::Or; \ } else if (reduce_type == Reduce::ReduceType::Sum) {
__VA_ARGS__; \ f(type_identity<cu::Sum>{});
} else if (REDUCE == Reduce::ReduceType::Sum) { \ } else if (reduce_type == Reduce::ReduceType::Prod) {
using OP = cu::Sum; \ f(type_identity<cu::Prod>{});
__VA_ARGS__; \ } else if (reduce_type == Reduce::ReduceType::Max) {
} else if (REDUCE == Reduce::ReduceType::Prod) { \ f(type_identity<cu::Max>{});
using OP = cu::Prod; \ } else if (reduce_type == Reduce::ReduceType::Min) {
__VA_ARGS__; \ f(type_identity<cu::Min>{});
} else if (REDUCE == Reduce::ReduceType::Max) { \ } else {
using OP = cu::Max; \ throw std::invalid_argument("Unknown reduce type.");
__VA_ARGS__; \
} else if (REDUCE == Reduce::ReduceType::Min) { \
using OP = cu::Min; \
__VA_ARGS__; \
} else { \
throw std::invalid_argument("Unknown reduce type."); \
} }
}
void all_reduce( void all_reduce(
cu::CommandEncoder& encoder, cu::CommandEncoder& encoder,

View File

@@ -247,9 +247,9 @@ void row_reduce_simple(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag); dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh) // Cub doesn't like const pointers for vectorized loads. (sigh)
@@ -295,9 +295,9 @@ void row_reduce_looped(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_all_types(in.dtype(), [&](auto type_tag) { dispatch_all_types(in.dtype(), [&](auto type_tag) {
using CTYPE = MLX_GET_TYPE(type_tag); dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, { using OP = MLX_GET_TYPE(reduce_type_tag);
using T = cuda_type_t<CTYPE>; using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
using U = typename cu::ReduceResult<OP, T>::type; using U = typename cu::ReduceResult<OP, T>::type;
// Cub doesn't like const pointers for vectorized loads. (sigh) // Cub doesn't like const pointers for vectorized loads. (sigh)
@@ -313,10 +313,16 @@ void row_reduce_looped(
// Pick the kernel // Pick the kernel
auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>; auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, { dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
MLX_SWITCH_BLOCK_DIM(threads, THREADS, { dispatch_block_dim(threads, [&](auto threads_constant) {
kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>; kernel = cu::row_reduce_looped<
block.x = THREADS; T,
U,
OP,
reduce_ndim(),
threads_constant(),
N_READS>;
block.x = threads_constant();
}); });
}); });

View File

@@ -226,18 +226,19 @@ void RMSNorm::eval_gpu(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) { dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr uint32_t N_READS = 4; constexpr uint32_t N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dispatch_block_dim(
auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>; cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
x.data<DataType>(), auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
w.data<DataType>(), kernel<<<n_rows, block_dim(), 0, stream>>>(
out.data<DataType>(), x.data<DataType>(),
eps_, w.data<DataType>(),
axis_size, out.data<DataType>(),
w_stride); eps_,
}); axis_size,
w_stride);
});
}); });
}); });
} }
@@ -312,21 +313,27 @@ void RMSNormVJP::eval_gpu(
encoder.set_output_array(gw_temp); encoder.set_output_array(gw_temp);
encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) { encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) { dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; dispatch_bool(has_w, [&](auto has_w_constant) {
constexpr int N_READS = 4; constexpr int N_READS = 4;
MLX_SWITCH_BOOL(has_w, HAS_W, { dispatch_block_dim(
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>; using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( constexpr int N_READS = 4;
x.data<DataType>(), auto kernel = cu::rms_norm_vjp<
w.data<DataType>(), DataType,
g.data<DataType>(), has_w_constant(),
gx.data<DataType>(), block_dim(),
gw_temp.data<DataType>(), N_READS>;
eps_, kernel<<<n_rows, block_dim(), 0, stream>>>(
axis_size, x.data<DataType>(),
w_stride); w.data<DataType>(),
}); g.data<DataType>(),
gx.data<DataType>(),
gw_temp.data<DataType>(),
eps_,
axis_size,
w_stride);
});
}); });
}); });
}); });

View File

@@ -311,11 +311,11 @@ void RoPE::eval_gpu(
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>; dispatch_bool(traditional_, [&](auto traditional) {
MLX_SWITCH_BOOL(traditional_, TRADITIONAL, { dispatch_bool(forward_, [&](auto forward) {
MLX_SWITCH_BOOL(forward_, FORWARD, { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (single && !with_freqs) { if (single && !with_freqs) {
auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>; auto kernel = cu::rope_single<DataType, traditional(), forward()>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>( kernel<<<grid, block, 0, stream>>>(
@@ -327,7 +327,8 @@ void RoPE::eval_gpu(
mat_size, mat_size,
dims); dims);
} else if (single) { } else if (single) {
auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>; auto kernel =
cu::rope_single_freqs<DataType, traditional(), forward()>;
uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size); uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
kernel<<<grid, block, 0, stream>>>( kernel<<<grid, block, 0, stream>>>(
@@ -340,7 +341,7 @@ void RoPE::eval_gpu(
dims, dims,
inputs[2].strides(0)); inputs[2].strides(0));
} else if (with_freqs) { } else if (with_freqs) {
auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>; auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
uint3 dims = uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4; dims.z = (dims.z + 3) / 4;
@@ -358,7 +359,7 @@ void RoPE::eval_gpu(
dims, dims,
inputs[2].strides(0)); inputs[2].strides(0));
} else { } else {
auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>; auto kernel = cu::rope<DataType, traditional(), forward()>;
uint3 dims = uint3 dims =
make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size); make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
dims.z = (dims.z + 3) / 4; dims.z = (dims.z + 3) / 4;

View File

@@ -143,16 +143,17 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.set_output_array(out); encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) { encoder.launch_kernel([&](cudaStream_t stream) {
dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) { dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
constexpr int N_READS = 4; constexpr int N_READS = 4;
MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, { dispatch_block_dim(
auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>; cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
if (precise) { using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>; auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
} if (precise) {
kernel<<<n_rows, BLOCK_DIM, 0, stream>>>( kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
in.data<DataType>(), out.data<DataType>(), axis_size); }
}); kernel<<<n_rows, block_dim(), 0, stream>>>(
in.data<DataType>(), out.data<DataType>(), axis_size);
});
}); });
}); });
} }

View File

@@ -97,53 +97,56 @@ void ternary_op_gpu_inplace(
auto topt = get_ternary_op_type(a, b, c); auto topt = get_ternary_op_type(a, b, c);
if (topt == TernaryOpType::General) { if (topt == TernaryOpType::General) {
auto [shape, strides] = collapse_contiguous_dims(a, b, c, out); dispatch_bool(
auto& a_strides = strides[0]; a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
auto& b_strides = strides[1]; c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
auto& c_strides = strides[2]; [&](auto large) {
bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX || using IdxT = std::conditional_t<large(), int64_t, int32_t>;
c.data_size() > INT32_MAX || out.data_size() > INT32_MAX; auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
MLX_SWITCH_BOOL(large, LARGE, { auto& a_strides = strides[0];
using IdxT = std::conditional_t<LARGE, int64_t, int32_t>; auto& b_strides = strides[1];
int ndim = shape.size(); auto& c_strides = strides[2];
if (ndim <= 3) { int ndim = shape.size();
MLX_SWITCH_1_2_3(ndim, NDIM, { if (ndim <= 3) {
auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>; dispatch_1_2_3(ndim, [&](auto dims_constant) {
auto [num_blocks, block_dims] = auto kernel =
get_launch_args(kernel, out, large); cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
kernel<<<num_blocks, block_dims, 0, stream>>>( auto [num_blocks, block_dims] =
a.data<bool>(), get_launch_args(kernel, out, large());
b.data<DType>(), kernel<<<num_blocks, block_dims, 0, stream>>>(
c.data<DType>(), a.data<bool>(),
out.data<DType>(), b.data<DType>(),
out.size(), c.data<DType>(),
const_param<NDIM>(shape), out.data<DType>(),
const_param<NDIM>(a_strides), out.size(),
const_param<NDIM>(b_strides), const_param<dims_constant()>(shape),
const_param<NDIM>(c_strides)); const_param<dims_constant()>(a_strides),
const_param<dims_constant()>(b_strides),
const_param<dims_constant()>(c_strides));
});
} else {
auto kernel = cu::ternary_g<Op, DType, IdxT>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out, large());
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
const_param(c_strides),
ndim);
}
}); });
} else {
auto kernel = cu::ternary_g<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(),
b.data<DType>(),
c.data<DType>(),
out.data<DType>(),
out.data_size(),
const_param(shape),
const_param(a_strides),
const_param(b_strides),
const_param(c_strides),
ndim);
}
});
} else { } else {
MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, { dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>; using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
auto kernel = cu::ternary_v<Op, DType, IdxT>; auto kernel = cu::ternary_v<Op, DType, IdxT>;
auto [num_blocks, block_dims] = get_launch_args( auto [num_blocks, block_dims] = get_launch_args(
kernel, out.data_size(), out.shape(), out.strides(), LARGE); kernel, out.data_size(), out.shape(), out.strides(), large());
kernel<<<num_blocks, block_dims, 0, stream>>>( kernel<<<num_blocks, block_dims, 0, stream>>>(
a.data<bool>(), a.data<bool>(),
b.data<DType>(), b.data<DType>(),