MLX_SWITCH macros to templates (#2320)

parent 33bf1a244b
commit 3d5e17e507
@@ -152,26 +152,20 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_REAL_TYPES_CHECKED(in.dtype(), "ArgReduce", CTYPE, {
-      using InType = cuda_type_t<CTYPE>;
+    dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
+      using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
       constexpr uint32_t N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
         dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
-        dim3 block_dims{BLOCK_DIM, 1, 1};
-        auto kernel = &cu::arg_reduce_general<
-            InType,
-            cu::ArgMax<InType>,
-            BLOCK_DIM,
-            N_READS>;
+        auto kernel =
+            cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
         if (reduce_type_ == ArgReduce::ArgMin) {
-          kernel = &cu::arg_reduce_general<
-              InType,
-              cu::ArgMin<InType>,
-              BLOCK_DIM,
-              N_READS>;
+          kernel = cu::
+              arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
         }
-        kernel<<<num_blocks, block_dims, 0, stream>>>(
-            in.data<InType>(),
+        kernel<<<num_blocks, block_dim(), 0, stream>>>(
+            in.data<T>(),
             out.data<uint32_t>(),
             out.size(),
             const_param(shape),
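The hunk above illustrates the mechanism used throughout this commit: the macro's textual constant (BLOCK_DIM) becomes a tag argument to a generic lambda, and because std::integral_constant's call operator is constexpr, block_dim() remains usable in template-argument lists. A minimal standalone sketch of just that C++ mechanism (names are illustrative, not MLX API):

    #include <type_traits>

    template <int N>
    struct Kernel {}; // stand-in for a kernel template

    constexpr auto tag = std::integral_constant<int, 128>{};
    static_assert(tag() == 128); // operator() is constexpr
    using K = Kernel<tag()>;     // so tag() is a valid template argument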
@@ -140,40 +140,50 @@ void binary_op_gpu_inplace(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
-      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+    dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
         if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
           using InType = cuda_type_t<CTYPE_IN>;
           using OutType = cuda_type_t<CTYPE_OUT>;
           auto bopt = get_binary_op_type(a, b);
           if (bopt == BinaryOpType::General) {
-            auto [shape, strides] = collapse_contiguous_dims(a, b, out);
+            dispatch_bool(
+                a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                    out.data_size() > INT32_MAX,
+                [&](auto large) {
+                  using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+                  Shape shape;
+                  std::vector<Strides> strides;
+                  std::tie(shape, strides) =
+                      collapse_contiguous_dims(a, b, out);
             auto& a_strides = strides[0];
             auto& b_strides = strides[1];
-            bool large = a.data_size() > INT32_MAX ||
-                b.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
-            MLX_SWITCH_BOOL(large, LARGE, {
-              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
            int ndim = shape.size();
            if (ndim <= 3) {
-              MLX_SWITCH_1_2_3(ndim, NDIM, {
-                auto kernel =
-                    &cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+              dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                auto kernel = cu::binary_g_nd<
+                    Op,
+                    InType,
+                    OutType,
+                    IdxT,
+                    dims_constant()>;
                 auto [num_blocks, block_dims] =
-                    get_launch_args(kernel, out, large);
+                    get_launch_args(kernel, out, large());
                 kernel<<<num_blocks, block_dims, 0, stream>>>(
                     a.data<InType>(),
                     b.data<InType>(),
                     out.data<OutType>(),
                     out.size(),
-                    const_param<NDIM>(shape),
-                    const_param<NDIM>(a_strides),
-                    const_param<NDIM>(b_strides));
+                    const_param<dims_constant()>(shape),
+                    const_param<dims_constant()>(a_strides),
+                    const_param<dims_constant()>(b_strides));
               });
             } else {
               auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
               auto [num_blocks, block_dims] =
-                  get_launch_args(kernel, out, large);
+                  get_launch_args(kernel, out, large());
               kernel<<<num_blocks, block_dims, 0, stream>>>(
                   a.data<InType>(),
                   b.data<InType>(),
@@ -186,8 +196,8 @@ void binary_op_gpu_inplace(
             }
           });
       } else {
-        MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
-          using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+        dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
           auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
           if (bopt == BinaryOpType::ScalarVector) {
             kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
@@ -197,7 +207,7 @@ void binary_op_gpu_inplace(
             kernel = cu::binary_vv<Op, InType, OutType, IdxT>;
           }
           auto [num_blocks, block_dims] = get_launch_args(
-              kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+              kernel, out.data_size(), out.shape(), out.strides(), large());
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               a.data<InType>(),
               b.data<InType>(),
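A pattern repeated in both hunks above is dispatch_bool plus a conditional index type: a runtime overflow check selects 32- or 64-bit indexing at compile time. A compilable sketch using the dispatch_bool definition this commit adds to kernel_utils.cuh (the launch wrapper is illustrative):

    #include <cstdint>
    #include <type_traits>

    template <typename F>
    void dispatch_bool(bool v, F&& f) {
      if (v) {
        f(std::true_type{});
      } else {
        f(std::false_type{});
      }
    }

    void launch(unsigned long long data_size) {
      dispatch_bool(data_size > INT32_MAX, [&](auto large) {
        // large() is constexpr; each branch instantiates its own IdxT.
        using IdxT = std::conditional_t<large(), int64_t, int32_t>;
        static_assert(sizeof(IdxT) == (large() ? 8 : 4));
      });
    }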
@@ -138,42 +138,52 @@ void binary_op_gpu_inplace(
   encoder.set_output_array(out_a);
   encoder.set_output_array(out_b);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE_IN, {
-      MLX_SWITCH_ALL_TYPES(out_a.dtype(), CTYPE_OUT, {
+    dispatch_all_types(a.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out_a.dtype(), [&](auto out_type_tag) {
+        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
         if constexpr (cu::supports_binary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
           using InType = cuda_type_t<CTYPE_IN>;
           using OutType = cuda_type_t<CTYPE_OUT>;
 
           auto bopt = get_binary_op_type(a, b);
           if (bopt == BinaryOpType::General) {
-            auto [shape, strides] = collapse_contiguous_dims(a, b, out_a);
+            dispatch_bool(
+                a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                    out_a.data_size() > INT32_MAX,
+                [&](auto large) {
+                  using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+                  Shape shape;
+                  std::vector<Strides> strides;
+                  std::tie(shape, strides) =
+                      collapse_contiguous_dims(a, b, out_a);
            auto& a_strides = strides[0];
            auto& b_strides = strides[1];
-            bool large = a.data_size() > INT32_MAX ||
-                b.data_size() > INT32_MAX || out_a.data_size() > INT32_MAX;
-            MLX_SWITCH_BOOL(large, LARGE, {
-              using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
            int ndim = shape.size();
            if (ndim <= 3) {
-              MLX_SWITCH_1_2_3(ndim, NDIM, {
-                auto kernel =
-                    cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
+              dispatch_1_2_3(ndim, [&](auto dims_constant) {
+                auto kernel = cu::binary_g_nd<
+                    Op,
+                    InType,
+                    OutType,
+                    IdxT,
+                    dims_constant()>;
                 auto [num_blocks, block_dims] =
-                    get_launch_args(kernel, out_a, large);
+                    get_launch_args(kernel, out_a, large());
                 kernel<<<num_blocks, block_dims, 0, stream>>>(
                     a.data<InType>(),
                     b.data<InType>(),
                     out_a.data<OutType>(),
                     out_b.data<OutType>(),
                     out_a.size(),
-                    const_param<NDIM>(shape),
-                    const_param<NDIM>(a_strides),
-                    const_param<NDIM>(b_strides));
+                    const_param<dims_constant()>(shape),
+                    const_param<dims_constant()>(a_strides),
+                    const_param<dims_constant()>(b_strides));
               });
             } else {
               auto kernel = cu::binary_g<Op, InType, OutType, IdxT>;
               auto [num_blocks, block_dims] =
-                  get_launch_args(kernel, out_a, large);
+                  get_launch_args(kernel, out_a, large());
               kernel<<<num_blocks, block_dims, 0, stream>>>(
                   a.data<InType>(),
                   b.data<InType>(),
@@ -187,8 +197,8 @@ void binary_op_gpu_inplace(
             }
           });
       } else {
-        MLX_SWITCH_BOOL(out_a.data_size() > UINT32_MAX, LARGE, {
-          using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+        dispatch_bool(out_a.data_size() > INT32_MAX, [&](auto large) {
+          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
           auto kernel = cu::binary_ss<Op, InType, OutType, IdxT>;
           if (bopt == BinaryOpType::ScalarVector) {
             kernel = cu::binary_sv<Op, InType, OutType, IdxT>;
@@ -202,7 +212,7 @@ void binary_op_gpu_inplace(
               out_a.data_size(),
               out_a.shape(),
               out_a.strides(),
-              LARGE);
+              large());
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               a.data<InType>(),
               b.data<InType>(),
@@ -10,15 +10,6 @@
 
 namespace mlx::core {
 
-#define MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, ...) \
-  MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, { \
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, { \
-      using InType = cuda_type_t<CTYPE_IN>; \
-      using OutType = cuda_type_t<CTYPE_OUT>; \
-      __VA_ARGS__; \
-    }); \
-  })
-
 void copy_contiguous(
     cu::CommandEncoder& encoder,
     CopyType ctype,
@@ -36,15 +36,18 @@ void copy_contiguous(
     int64_t in_offset,
     int64_t out_offset) {
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
-      MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
-        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+          using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+          using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+          using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
           auto kernel = cu::copy_s<InType, OutType, IdxT>;
           if (ctype == CopyType::Vector) {
             kernel = cu::copy_v<InType, OutType, IdxT>;
           }
           auto [num_blocks, block_dims] = get_launch_args(
-              kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+              kernel, out.data_size(), out.shape(), out.strides(), large());
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               in.data<InType>() + in_offset,
               out.data<OutType>() + out_offset,
@@ -52,6 +55,7 @@ void copy_contiguous(
         });
       });
     });
+  });
 }
 
 } // namespace mlx::core
@@ -56,33 +56,38 @@ void copy_general(
     const Strides& strides_in,
     const Strides& strides_out) {
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
+    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        dispatch_bool(
+            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+            [&](auto large) {
+              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
       const InType* in_ptr = in.data<InType>() + offset_in;
       OutType* out_ptr = out.data<OutType>() + offset_out;
-      bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
-      MLX_SWITCH_BOOL(large, LARGE, {
-        using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
      int ndim = shape.size();
      size_t data_size = 1;
      for (auto& s : shape)
        data_size *= s;
      if (ndim <= 3) {
-        MLX_SWITCH_1_2_3(ndim, NDIM, {
-          auto kernel = cu::copy_gg_nd<InType, OutType, IdxT, NDIM>;
-          auto [num_blocks, block_dims] =
-              get_launch_args(kernel, data_size, shape, out.strides(), large);
+        dispatch_1_2_3(ndim, [&](auto ndim_constant) {
+          auto kernel =
+              cu::copy_gg_nd<InType, OutType, IdxT, ndim_constant()>;
+          auto [num_blocks, block_dims] = get_launch_args(
+              kernel, data_size, shape, out.strides(), large());
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               in_ptr,
               out_ptr,
               data_size,
-              const_param<NDIM>(shape),
-              const_param<NDIM>(strides_in),
-              const_param<NDIM>(strides_out));
+              const_param<ndim_constant()>(shape),
+              const_param<ndim_constant()>(strides_in),
+              const_param<ndim_constant()>(strides_out));
         });
       } else { // ndim >= 4
         auto kernel = cu::copy_gg<InType, OutType, IdxT>;
-        auto [num_blocks, block_dims] =
-            get_launch_args(kernel, data_size, shape, out.strides(), large);
+        auto [num_blocks, block_dims] = get_launch_args(
+            kernel, data_size, shape, out.strides(), large());
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             in_ptr,
             out_ptr,
@@ -95,6 +100,7 @@ void copy_general(
       });
     });
   });
+  });
 }
 
 } // namespace mlx::core
@@ -62,30 +62,40 @@ void copy_general_dynamic(
     const array& dynamic_offset_in,
     const array& dynamic_offset_out) {
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
+    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        dispatch_bool(
+            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+            [&](auto large) {
+              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
       const InType* in_ptr = in.data<InType>() + offset_in;
       OutType* out_ptr = out.data<OutType>() + offset_out;
-      bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
-      MLX_SWITCH_BOOL(large, LARGE, {
-        using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
      int ndim = shape.size();
      if (ndim <= 3) {
-        MLX_SWITCH_1_2_3(ndim, NDIM, {
-          auto kernel = cu::copy_gg_dynamic_nd<InType, OutType, IdxT, NDIM>;
-          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        dispatch_1_2_3(ndim, [&](auto dims_constant) {
+          auto kernel = cu::copy_gg_dynamic_nd<
+              InType,
+              OutType,
+              IdxT,
+              dims_constant()>;
+          auto [num_blocks, block_dims] =
+              get_launch_args(kernel, out, large());
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               in_ptr,
               out_ptr,
               out.size(),
-              const_param<NDIM>(shape),
-              const_param<NDIM>(strides_in),
-              const_param<NDIM>(strides_out),
+              const_param<dims_constant()>(shape),
+              const_param<dims_constant()>(strides_in),
+              const_param<dims_constant()>(strides_out),
               dynamic_offset_in.data<int64_t>(),
               dynamic_offset_out.data<int64_t>());
         });
       } else { // ndim >= 4
         auto kernel = cu::copy_gg_dynamic<InType, OutType, IdxT>;
-        auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        auto [num_blocks, block_dims] =
+            get_launch_args(kernel, out, large());
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             in_ptr,
             out_ptr,
@@ -100,6 +110,7 @@ void copy_general_dynamic(
       });
     });
   });
+  });
 }
 
 } // namespace mlx::core
@@ -51,27 +51,34 @@ void copy_general_input(
     const Shape& shape,
     const Strides& strides_in) {
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_COPY_TYPES(in, out, InType, OutType, {
+    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        dispatch_bool(
+            in.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+            [&](auto large) {
+              using InType = cuda_type_t<MLX_GET_TYPE(in_type_tag)>;
+              using OutType = cuda_type_t<MLX_GET_TYPE(out_type_tag)>;
+              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
       const InType* in_ptr = in.data<InType>() + offset_in;
       OutType* out_ptr = out.data<OutType>() + offset_out;
-      bool large = in.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
-      MLX_SWITCH_BOOL(large, LARGE, {
-        using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
      int ndim = shape.size();
      if (ndim <= 3) {
-        MLX_SWITCH_1_2_3(ndim, NDIM, {
-          auto kernel = cu::copy_g_nd<InType, OutType, IdxT, NDIM>;
-          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        dispatch_1_2_3(ndim, [&](auto dims_constant) {
+          auto kernel =
+              cu::copy_g_nd<InType, OutType, IdxT, dims_constant()>;
+          auto [num_blocks, block_dims] =
+              get_launch_args(kernel, out, large());
           kernel<<<num_blocks, block_dims, 0, stream>>>(
               in_ptr,
               out_ptr,
               out.size(),
-              const_param<NDIM>(shape),
-              const_param<NDIM>(strides_in));
+              const_param<dims_constant()>(shape),
+              const_param<dims_constant()>(strides_in));
         });
       } else { // ndim >= 4
         auto kernel = cu::copy_g<InType, OutType, IdxT>;
-        auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+        auto [num_blocks, block_dims] =
+            get_launch_args(kernel, out, large());
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             in_ptr,
             out_ptr,
@@ -83,6 +90,7 @@ void copy_general_input(
       });
     });
   });
+  });
 }
 
 } // namespace mlx::core
@@ -6,6 +6,8 @@
 
 #pragma once
 
+#include <type_traits>
+
 #include "mlx/array.h"
 #include "mlx/backend/cuda/device/utils.cuh"
 
@@ -17,60 +19,46 @@
 
 namespace mlx::core {
 
-// Convert a number between 1~3 to constexpr.
-#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
-  switch (N) {                         \
-    case 1: {                          \
-      constexpr int NDIM = 1;          \
-      __VA_ARGS__;                     \
-      break;                           \
-    }                                  \
-    case 2: {                          \
-      constexpr int NDIM = 2;          \
-      __VA_ARGS__;                     \
-      break;                           \
-    }                                  \
-    case 3: {                          \
-      constexpr int NDIM = 3;          \
-      __VA_ARGS__;                     \
-      break;                           \
-    }                                  \
+template <typename F>
+void dispatch_1_2_3(int n, F&& f) {
+  switch (n) {
+    case 1:
+      f(std::integral_constant<int, 1>{});
+      break;
+    case 2:
+      f(std::integral_constant<int, 2>{});
+      break;
+    case 3:
+      f(std::integral_constant<int, 3>{});
+      break;
   }
+}
 
-// Like MLX_SWITCH_ALL_TYPES but for booleans.
-#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
-  if (BOOL) {                                  \
-    constexpr bool BOOL_ALIAS = true;          \
-    __VA_ARGS__;                               \
-  } else {                                     \
-    constexpr bool BOOL_ALIAS = false;         \
-    __VA_ARGS__;                               \
+template <typename F>
+void dispatch_bool(bool v, F&& f) {
+  if (v) {
+    f(std::true_type{});
+  } else {
+    f(std::false_type{});
   }
+}
 
-// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
-#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...)   \
-  {                                                         \
-    uint32_t _num_threads = NUM_THREADS;                    \
-    if (_num_threads <= WARP_SIZE) {                        \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE;             \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 2) {             \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2;         \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 4) {             \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4;         \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 8) {             \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8;         \
-      __VA_ARGS__;                                          \
-    } else if (_num_threads <= WARP_SIZE * 16) {            \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16;        \
-      __VA_ARGS__;                                          \
-    } else {                                                \
-      constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
-      __VA_ARGS__;                                          \
-    }                                                       \
+template <typename F>
+void dispatch_block_dim(int threads, F&& f) {
+  if (threads <= WARP_SIZE) {
+    f(std::integral_constant<int, WARP_SIZE>{});
+  } else if (threads <= WARP_SIZE * 2) {
+    f(std::integral_constant<int, WARP_SIZE * 2>{});
+  } else if (threads <= WARP_SIZE * 4) {
+    f(std::integral_constant<int, WARP_SIZE * 4>{});
+  } else if (threads <= WARP_SIZE * 8) {
+    f(std::integral_constant<int, WARP_SIZE * 8>{});
+  } else if (threads <= WARP_SIZE * 16) {
+    f(std::integral_constant<int, WARP_SIZE * 16>{});
+  } else {
+    f(std::integral_constant<int, WARP_SIZE * 32>{});
   }
+}
 
 // Maps CPU types to CUDA types.
 template <typename T>
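All call sites in this commit use these dispatchers the same way: pass the runtime value and a generic lambda, then read the compile-time value back from the tag inside it. A hedged usage sketch (the stub kernel is illustrative, not MLX code):

    template <int NDIM>
    void kernel_stub() {} // stand-in for a __global__ kernel template

    void example(int ndim) {
      dispatch_1_2_3(ndim, [&](auto dims_constant) {
        // The lambda body is compiled once per case, so the kernel is
        // instantiated for NDIM = 1, 2, and 3, as with the old macro.
        kernel_stub<dims_constant()>();
      });
    }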
@@ -259,12 +259,13 @@ void LayerNorm::eval_gpu(
   encoder.set_input_array(b);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "layernorm", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
+    dispatch_float_types(out.dtype(), "layernorm", [&](auto type_tag) {
       constexpr uint32_t N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        auto kernel = cu::layer_norm<DataType, BLOCK_DIM, N_READS>;
-        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+            auto kernel = cu::layer_norm<DataType, block_dim(), N_READS>;
+            kernel<<<n_rows, block_dim(), 0, stream>>>(
                 x.data<DataType>(),
                 w.data<DataType>(),
                 b.data<DataType>(),
@@ -357,13 +358,18 @@ void LayerNormVJP::eval_gpu(
   encoder.set_output_array(gx);
   encoder.set_output_array(gw_temp);
   encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "layernorm_vjp", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
+    dispatch_float_types(gx.dtype(), "layernorm_vjp", [&](auto type_tag) {
+      dispatch_bool(has_w, [&](auto has_w_constant) {
       constexpr int N_READS = 4;
-      MLX_SWITCH_BOOL(has_w, HAS_W, {
-        MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-          auto kernel = cu::layer_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
-          kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+        dispatch_block_dim(
+            cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+              using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+              auto kernel = cu::layer_norm_vjp<
+                  DataType,
+                  has_w_constant(),
+                  block_dim(),
+                  N_READS>;
+              kernel<<<n_rows, block_dim(), 0, stream>>>(
                   x.data<DataType>(),
                   w.data<DataType>(),
                   g.data<DataType>(),
@@ -144,12 +144,13 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "logsumexp", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
+    dispatch_float_types(out.dtype(), "logsumexp", [&](auto type_tag) {
       constexpr int N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        auto kernel = cu::logsumexp<DataType, float, BLOCK_DIM, N_READS>;
-        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+            auto kernel = cu::logsumexp<DataType, float, block_dim(), N_READS>;
+            kernel<<<n_rows, block_dim(), 0, stream>>>(
                 in.data<DataType>(), out.data<DataType>(), axis_size);
           });
     });
@@ -28,7 +28,8 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& encoder = cu::get_command_encoder(s);
   encoder.set_output_array(out);
   encoder.launch_kernel([&, this](cudaStream_t stream) {
-    MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(out.dtype(), "Arange", CTYPE, {
+    dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
+      using CTYPE = MLX_GET_TYPE(type_tag);
       using OutType = cuda_type_t<CTYPE>;
       CTYPE step =
           static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
@@ -111,10 +111,11 @@ void all_reduce(
   encoder.add_temporary(intermediate);
   encoder.set_output_array(intermediate);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(dt, [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
         auto kernel = cu::all_reduce<T, U, OP, N_READS>;
         kernel<<<blocks, threads, 0, stream>>>(
             static_cast<T*>(indata),
@@ -135,10 +136,11 @@ void all_reduce(
 
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(dt, CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(dt, [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
         auto kernel = cu::all_reduce<T, U, OP, N_READS>;
         kernel<<<blocks, threads, 0, stream>>>(
             static_cast<T*>(indata), out.data<U>(), block_step, insize);
@@ -215,11 +215,12 @@ void col_reduce_looped(
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          using T = cuda_type_t<CTYPE>;
-          using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
+          using OP = MLX_GET_TYPE(reduce_type_tag);
+          using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+          using U = typename cu::ReduceResult<OP, T>::type;
 
           // Cub doesn't like const pointers for vectorized loads. (sigh)
           T* indata = const_cast<T*>(in.data<T>());
@@ -229,7 +230,8 @@ void col_reduce_looped(
           constexpr int BN = 32;
           dim3 grid = output_grid_for_col_reduce(out, args, BN);
           int blocks = BM * BN / N_READS;
-          auto kernel = cu::col_reduce_looped<T, U, OP, NDIM, BM, BN, N_READS>;
+          auto kernel =
+              cu::col_reduce_looped<T, U, OP, reduce_ndim(), BM, BN, N_READS>;
           kernel<<<grid, blocks, 0, stream>>>(indata, out.data<U>(), args);
         });
       });
@@ -33,10 +33,11 @@ void init_reduce(
 
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
         auto kernel = cu::init_reduce<T, U, OP>;
         dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
         dim3 block(grid.x < 1024 ? grid.x : 1024, 1, 1);
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "mlx/backend/common/reduce.h"
|
#include "mlx/backend/common/reduce.h"
|
||||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
@ -9,43 +11,35 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
// Dispatch dynamic ndim to constexpr.
|
template <typename F>
|
||||||
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
void dispatch_reduce_ndim(int ndim, F&& f) {
|
||||||
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
if (ndim == 1) {
|
||||||
if (ndim == 1) { \
|
f(std::integral_constant<int, 1>{});
|
||||||
constexpr uint32_t NDIM = 1; \
|
} else if (ndim == 2) {
|
||||||
__VA_ARGS__; \
|
f(std::integral_constant<int, 2>{});
|
||||||
} else if (ndim == 2) { \
|
} else {
|
||||||
constexpr uint32_t NDIM = 2; \
|
f(std::integral_constant<int, 5>{});
|
||||||
__VA_ARGS__; \
|
|
||||||
} else { \
|
|
||||||
constexpr uint32_t NDIM = 5; \
|
|
||||||
__VA_ARGS__; \
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Dispatch reduce ops to constexpr.
|
template <typename F>
|
||||||
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {
|
||||||
if (REDUCE == Reduce::ReduceType::And) { \
|
if (reduce_type == Reduce::ReduceType::And) {
|
||||||
using OP = cu::And; \
|
f(type_identity<cu::And>{});
|
||||||
__VA_ARGS__; \
|
} else if (reduce_type == Reduce::ReduceType::Or) {
|
||||||
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
f(type_identity<cu::Or>{});
|
||||||
using OP = cu::Or; \
|
} else if (reduce_type == Reduce::ReduceType::Sum) {
|
||||||
__VA_ARGS__; \
|
f(type_identity<cu::Sum>{});
|
||||||
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
} else if (reduce_type == Reduce::ReduceType::Prod) {
|
||||||
using OP = cu::Sum; \
|
f(type_identity<cu::Prod>{});
|
||||||
__VA_ARGS__; \
|
} else if (reduce_type == Reduce::ReduceType::Max) {
|
||||||
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
f(type_identity<cu::Max>{});
|
||||||
using OP = cu::Prod; \
|
} else if (reduce_type == Reduce::ReduceType::Min) {
|
||||||
__VA_ARGS__; \
|
f(type_identity<cu::Min>{});
|
||||||
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
} else {
|
||||||
using OP = cu::Max; \
|
throw std::invalid_argument("Unknown reduce type.");
|
||||||
__VA_ARGS__; \
|
|
||||||
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
|
||||||
using OP = cu::Min; \
|
|
||||||
__VA_ARGS__; \
|
|
||||||
} else { \
|
|
||||||
throw std::invalid_argument("Unknown reduce type."); \
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void all_reduce(
|
void all_reduce(
|
||||||
cu::CommandEncoder& encoder,
|
cu::CommandEncoder& encoder,
|
||||||
|
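Unlike the integral_constant dispatchers, dispatch_reduce_ops passes the reduction operator as a type rather than a value, wrapping it in a type_identity tag so the lambda can recover it via MLX_GET_TYPE. A sketch of the same idea using C++20's std::type_identity (the diff's unqualified type_identity may be an MLX-local alias; Sum/Max here are illustrative):

    #include <type_traits>

    struct Sum {};
    struct Max {};

    template <typename F>
    void dispatch_op(bool use_max, F&& f) {
      if (use_max) {
        f(std::type_identity<Max>{});
      } else {
        f(std::type_identity<Sum>{});
      }
    }

    void example() {
      dispatch_op(true, [&](auto tag) {
        // Equivalent to MLX_GET_TYPE(tag): unwrap the tag's ::type.
        using OP = typename decltype(tag)::type;
        static_assert(std::is_same_v<OP, Max> || std::is_same_v<OP, Sum>);
      });
    }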
@@ -246,10 +246,11 @@ void row_reduce_simple(
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
 
         // Cub doesn't like const pointers for vectorized loads. (sigh)
         T* indata = const_cast<T*>(in.data<T>());
@@ -293,10 +294,11 @@ void row_reduce_looped(
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
-      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
-        using T = cuda_type_t<CTYPE>;
-        using U = cu::ReduceResult<OP, T>::type;
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+        using OP = MLX_GET_TYPE(reduce_type_tag);
+        using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+        using U = typename cu::ReduceResult<OP, T>::type;
 
         // Cub doesn't like const pointers for vectorized loads. (sigh)
         T* indata = const_cast<T*>(in.data<T>());
@@ -311,10 +313,16 @@ void row_reduce_looped(
 
         // Pick the kernel
         auto kernel = cu::row_reduce_looped<T, U, OP, 1, 32, N_READS>;
-        MLX_SWITCH_REDUCE_NDIM(args.reduce_ndim, NDIM, {
-          MLX_SWITCH_BLOCK_DIM(threads, THREADS, {
-            kernel = cu::row_reduce_looped<T, U, OP, NDIM, THREADS, N_READS>;
-            block.x = THREADS;
+        dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
+          dispatch_block_dim(threads, [&](auto threads_constant) {
+            kernel = cu::row_reduce_looped<
+                T,
+                U,
+                OP,
+                reduce_ndim(),
+                threads_constant(),
+                N_READS>;
+            block.x = threads_constant();
           });
         });
 
@@ -225,12 +225,13 @@ void RMSNorm::eval_gpu(
   encoder.set_input_array(w);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "rms_norm", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
+    dispatch_float_types(out.dtype(), "rms_norm", [&](auto type_tag) {
       constexpr uint32_t N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        auto kernel = cu::rms_norm<DataType, BLOCK_DIM, N_READS>;
-        kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+            auto kernel = cu::rms_norm<DataType, block_dim(), N_READS>;
+            kernel<<<n_rows, block_dim(), 0, stream>>>(
                 x.data<DataType>(),
                 w.data<DataType>(),
                 out.data<DataType>(),
@@ -311,13 +312,19 @@ void RMSNormVJP::eval_gpu(
   encoder.set_output_array(gx);
   encoder.set_output_array(gw_temp);
   encoder.launch_kernel([&, x = x, g = g](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(gx.dtype(), "rms_norm_vjp", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
+    dispatch_float_types(gx.dtype(), "rms_norm_vjp", [&](auto type_tag) {
+      dispatch_bool(has_w, [&](auto has_w_constant) {
       constexpr int N_READS = 4;
-      MLX_SWITCH_BOOL(has_w, HAS_W, {
-        MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-          auto kernel = cu::rms_norm_vjp<DataType, HAS_W, BLOCK_DIM, N_READS>;
-          kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+        dispatch_block_dim(
+            cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+              using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+              constexpr int N_READS = 4;
+              auto kernel = cu::rms_norm_vjp<
+                  DataType,
+                  has_w_constant(),
+                  block_dim(),
+                  N_READS>;
+              kernel<<<n_rows, block_dim(), 0, stream>>>(
                   x.data<DataType>(),
                   w.data<DataType>(),
                   g.data<DataType>(),
@@ -310,12 +310,12 @@ void RoPE::eval_gpu(
   encoder.set_input_array(offset);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(in.dtype(), "rope", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
-      MLX_SWITCH_BOOL(traditional_, TRADITIONAL, {
-        MLX_SWITCH_BOOL(forward_, FORWARD, {
+    dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
+      dispatch_bool(traditional_, [&](auto traditional) {
+        dispatch_bool(forward_, [&](auto forward) {
+          using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
           if (single && !with_freqs) {
-            auto kernel = cu::rope_single<DataType, TRADITIONAL, FORWARD>;
+            auto kernel = cu::rope_single<DataType, traditional(), forward()>;
             uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
             auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
             kernel<<<grid, block, 0, stream>>>(
@@ -327,7 +327,8 @@ void RoPE::eval_gpu(
                 mat_size,
                 dims);
           } else if (single) {
-            auto kernel = cu::rope_single_freqs<DataType, TRADITIONAL, FORWARD>;
+            auto kernel =
+                cu::rope_single_freqs<DataType, traditional(), forward()>;
             uint2 dims = make_uint2(dims_ / 2, in.size() / mat_size);
             auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
             kernel<<<grid, block, 0, stream>>>(
@@ -340,7 +341,7 @@ void RoPE::eval_gpu(
                 dims,
                 inputs[2].strides(0));
           } else if (with_freqs) {
-            auto kernel = cu::rope_freqs<DataType, TRADITIONAL, FORWARD>;
+            auto kernel = cu::rope_freqs<DataType, traditional(), forward()>;
             uint3 dims =
                 make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
             dims.z = (dims.z + 3) / 4;
@@ -358,7 +359,7 @@ void RoPE::eval_gpu(
                 dims,
                 inputs[2].strides(0));
           } else {
-            auto kernel = cu::rope<DataType, TRADITIONAL, FORWARD>;
+            auto kernel = cu::rope<DataType, traditional(), forward()>;
             uint3 dims =
                 make_uint3(dims_ / 2, in.shape(-2), in.size() / mat_size);
             dims.z = (dims.z + 3) / 4;
@@ -142,15 +142,16 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_FLOAT_TYPES_CHECKED(out.dtype(), "softmax", CTYPE, {
-      using DataType = cuda_type_t<CTYPE>;
+    dispatch_float_types(out.dtype(), "softmax", [&](auto type_tag) {
       constexpr int N_READS = 4;
-      MLX_SWITCH_BLOCK_DIM(cuda::ceil_div(axis_size, N_READS), BLOCK_DIM, {
-        auto kernel = cu::softmax<DataType, DataType, BLOCK_DIM, N_READS>;
+      dispatch_block_dim(
+          cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
+            using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+            auto kernel = cu::softmax<DataType, DataType, block_dim(), N_READS>;
             if (precise) {
-              kernel = cu::softmax<DataType, float, BLOCK_DIM, N_READS>;
+              kernel = cu::softmax<DataType, float, block_dim(), N_READS>;
             }
-            kernel<<<n_rows, BLOCK_DIM, 0, stream>>>(
+            kernel<<<n_rows, block_dim(), 0, stream>>>(
                 in.data<DataType>(), out.data<DataType>(), axis_size);
           });
     });
@@ -76,6 +76,14 @@ void segmented_sort(cu::CommandEncoder& encoder, Args&&... args) {
       temp.data<void>(), size, args...));
 }
 
+struct OffsetTransform {
+  int nsort;
+
+  int __device__ operator()(int i) {
+    return i * nsort;
+  }
+};
+
 void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   array out = out_;
   auto& encoder = cu::get_command_encoder(s);
@@ -106,12 +114,12 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
+    dispatch_all_types(in.dtype(), [&](auto type_tag) {
+      using CTYPE = MLX_GET_TYPE(type_tag);
       if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
         using Type = cuda_type_t<CTYPE>;
         auto offsets = thrust::make_transform_iterator(
-            thrust::make_counting_iterator(0),
-            [nsort] __device__(int i) { return i * nsort; });
+            thrust::make_counting_iterator(0), OffsetTransform{nsort});
         if (argsort) {
           // Indices in the sorted dimension.
           array indices(
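The sort change above also replaces an extended __device__ lambda with the named OffsetTransform functor. This is likely necessary because the call site now sits inside the generic lambda passed to dispatch_all_types, and nvcc does not allow extended device lambdas inside generic lambdas; a stateful functor carries no such restriction. A host-side analogue of the functor (illustrative; the real one is __device__ and feeds thrust::make_transform_iterator):

    struct OffsetTransform {
      int nsort;
      int operator()(int i) const {
        return i * nsort; // segment i starts at offset i * nsort
      }
    };

    // Usage sketch: OffsetTransform{4}(3) == 12.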
@@ -92,39 +92,44 @@ void ternary_op_gpu_inplace(
   encoder.set_input_array(c);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE, {
-      using DType = cuda_type_t<CTYPE>;
+    dispatch_all_types(out.dtype(), [&](auto type_tag) {
+      using DType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
 
       auto topt = get_ternary_op_type(a, b, c);
       if (topt == TernaryOpType::General) {
-        auto [shape, strides] = collapse_contiguous_dims(a, b, c, out);
+        dispatch_bool(
+            a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
+                c.data_size() > INT32_MAX || out.data_size() > INT32_MAX,
+            [&](auto large) {
+              using IdxT = std::conditional_t<large(), int64_t, int32_t>;
+              Shape shape;
+              std::vector<Strides> strides;
+              std::tie(shape, strides) = collapse_contiguous_dims(a, b, c, out);
        auto& a_strides = strides[0];
        auto& b_strides = strides[1];
        auto& c_strides = strides[2];
-        bool large = a.data_size() > INT32_MAX || b.data_size() > INT32_MAX ||
-            c.data_size() > INT32_MAX || out.data_size() > INT32_MAX;
-        MLX_SWITCH_BOOL(large, LARGE, {
-          using IdxT = std::conditional_t<LARGE, int64_t, int32_t>;
        int ndim = shape.size();
        if (ndim <= 3) {
-          MLX_SWITCH_1_2_3(ndim, NDIM, {
-            auto kernel = cu::ternary_g_nd<Op, DType, IdxT, NDIM>;
+          dispatch_1_2_3(ndim, [&](auto dims_constant) {
+            auto kernel =
+                cu::ternary_g_nd<Op, DType, IdxT, dims_constant()>;
             auto [num_blocks, block_dims] =
-                get_launch_args(kernel, out, large);
+                get_launch_args(kernel, out, large());
             kernel<<<num_blocks, block_dims, 0, stream>>>(
                 a.data<bool>(),
                 b.data<DType>(),
                 c.data<DType>(),
                 out.data<DType>(),
                 out.size(),
-                const_param<NDIM>(shape),
-                const_param<NDIM>(a_strides),
-                const_param<NDIM>(b_strides),
-                const_param<NDIM>(c_strides));
+                const_param<dims_constant()>(shape),
+                const_param<dims_constant()>(a_strides),
+                const_param<dims_constant()>(b_strides),
+                const_param<dims_constant()>(c_strides));
           });
        } else {
          auto kernel = cu::ternary_g<Op, DType, IdxT>;
-          auto [num_blocks, block_dims] = get_launch_args(kernel, out, large);
+          auto [num_blocks, block_dims] =
+              get_launch_args(kernel, out, large());
          kernel<<<num_blocks, block_dims, 0, stream>>>(
              a.data<bool>(),
              b.data<DType>(),
@@ -139,11 +144,11 @@ void ternary_op_gpu_inplace(
        }
        });
      } else {
-      MLX_SWITCH_BOOL(out.data_size() > UINT32_MAX, LARGE, {
-        using IdxT = std::conditional_t<LARGE, int64_t, uint32_t>;
+      dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
+        using IdxT = std::conditional_t<large(), int64_t, uint32_t>;
         auto kernel = cu::ternary_v<Op, DType, IdxT>;
         auto [num_blocks, block_dims] = get_launch_args(
-            kernel, out.data_size(), out.shape(), out.strides(), LARGE);
+            kernel, out.data_size(), out.shape(), out.strides(), large());
         kernel<<<num_blocks, block_dims, 0, stream>>>(
             a.data<bool>(),
             b.data<DType>(),
@@ -79,8 +79,10 @@ void unary_op_gpu_inplace(
   encoder.set_input_array(in);
   encoder.set_output_array(out);
   encoder.launch_kernel([&](cudaStream_t stream) {
-    MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE_IN, {
-      MLX_SWITCH_ALL_TYPES(out.dtype(), CTYPE_OUT, {
+    dispatch_all_types(in.dtype(), [&](auto in_type_tag) {
+      dispatch_all_types(out.dtype(), [&](auto out_type_tag) {
+        using CTYPE_IN = MLX_GET_TYPE(in_type_tag);
+        using CTYPE_OUT = MLX_GET_TYPE(out_type_tag);
         if constexpr (cu::supports_unary_op<Op, CTYPE_IN, CTYPE_OUT>()) {
           using InType = cuda_type_t<CTYPE_IN>;
           using OutType = cuda_type_t<CTYPE_OUT>;
@@ -25,22 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) {
 }
 
 const char* dtype_to_cuda_type(const Dtype& dtype) {
-  if (dtype == float16) {
+  switch (dtype) {
+    case bool_:
+      return "bool";
+    case int8:
+      return "int8_t";
+    case int16:
+      return "int16_t";
+    case int32:
+      return "int32_t";
+    case int64:
+      return "int64_t";
+    case uint8:
+      return "uint8_t";
+    case uint16:
+      return "uint16_t";
+    case uint32:
+      return "uint32_t";
+    case uint64:
+      return "uint64_t";
+    case float16:
       return "__half";
-  }
-  if (dtype == bfloat16) {
+    case bfloat16:
       return "__nv_bfloat16";
-  }
-  if (dtype == complex64) {
+    case float32:
+      return "float";
+    case float64:
+      return "double";
+    case complex64:
       return "cuComplex";
+    default:
+      return "unknown";
   }
-#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
-  if (dtype == DTYPE) {                           \
-    return #CPP_TYPE;                             \
-  }
-  MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
-#undef SPECIALIZE_DtypeToString
-  return nullptr;
 }
 
 } // namespace mlx::core
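Note: the new `default: return "unknown"` replaces the old `return nullptr` fallback, so callers always get a printable string. The names returned are CUDA source-level type names (`__half`, `cuComplex`, ...), which suggests use when assembling JIT-compiled kernel source; a hedged illustration (the function and variable names here are ours, not the MLX API):

    #include <string>

    // Illustration only: composing a kernel parameter declaration from the
    // CUDA type name. With a float16 dtype this yields "__half* in".
    std::string make_param_decl(const Dtype& dtype) {
      return std::string(dtype_to_cuda_type(dtype)) + "* in";
    }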
@@ -5,16 +5,38 @@
 namespace mlx::core {
 
 const char* dtype_to_string(Dtype arg) {
-  if (arg == bool_) {
+  switch (arg) {
+    case bool_:
       return "bool";
+    case int8:
+      return "int8";
+    case int16:
+      return "int16";
+    case int32:
+      return "int32";
+    case int64:
+      return "int64";
+    case uint8:
+      return "uint8";
+    case uint16:
+      return "uint16";
+    case uint32:
+      return "uint32";
+    case uint64:
+      return "uint64";
+    case float16:
+      return "float16";
+    case bfloat16:
+      return "bfloat16";
+    case float32:
+      return "float32";
+    case float64:
+      return "float64";
+    case complex64:
+      return "complex64";
+    default:
+      return "unknown";
   }
-#define SPECIALIZE_DtypeToString(CPP_TYPE, DTYPE) \
-  if (DTYPE == arg) {                             \
-    return #DTYPE;                                \
-  }
-  MLX_FORALL_DTYPES(SPECIALIZE_DtypeToString)
-#undef SPECIALIZE_DtypeToString
-  return "(unknown)";
 }
 
 } // namespace mlx::core
@@ -1,207 +1,106 @@
 // Copyright © 2025 Apple Inc.
-// Copyright © Meta Platforms, Inc. and affiliates.
-//
-// This source code is licensed under the BSD-style license found in
-// https://github.com/pytorch/executorch/blob/main/LICENSE
-//
-// Forked from
-// https://github.com/pytorch/executorch/blob/main/runtime/core/exec_aten/util/scalar_type_util.h
 
 #pragma once
 
-#include "mlx/dtype.h"
+#include <sstream>
 
-#include <fmt/format.h>
+#include "mlx/dtype.h"
+#include "mlx/utils.h"
 
 namespace mlx::core {
 
 // Return string representation of dtype.
 const char* dtype_to_string(Dtype arg);
 
-// Macros that iterate across different subsets of Dtypes.
-//
-// For all of these macros, the final `_` parameter is the name of another
-// macro that takes two parameters: the name of a C type, and the name of the
-// corresponding Dtype enumerator.
-//
-// Note that these macros should use fully-qualified namespaces (starting with
-// `::`) to ensure that they can be called safely in any arbitrary namespace.
-#define MLX_FORALL_INT_TYPES(_) \
-  _(uint8_t, uint8)             \
-  _(uint16_t, uint16)           \
-  _(uint32_t, uint32)           \
-  _(uint64_t, uint64)           \
-  _(int8_t, int8)               \
-  _(int16_t, int16)             \
-  _(int32_t, int32)             \
-  _(int64_t, int64)
-
-#define MLX_FORALL_FLOAT_TYPES(_) \
-  _(float16_t, float16)           \
-  _(float, float32)               \
-  _(double, float64)              \
-  _(bfloat16_t, bfloat16)
-
-// Calls the provided macro on every Dtype, providing the C type and the
-// Dtype name to each call.
-//
-// @param _ A macro that takes two parameters: the name of a C type, and the
-// name of the corresponding Dtype enumerator.
-#define MLX_FORALL_DTYPES(_) \
-  MLX_FORALL_INT_TYPES(_)    \
-  MLX_FORALL_FLOAT_TYPES(_)  \
-  _(bool, bool_)             \
-  _(complex64_t, complex64)
+#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \
+  case DTYPE:                                       \
+    f(type_identity<TYPE>{});                       \
+    break
+
+#define MLX_INTERNAL_DTYPE_SWITCH_INTS()            \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t);     \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)
+
+#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS()              \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t);   \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float);       \
+  MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double)
 
-// Maps Dtypes to C++ types.
-template <Dtype::Val N>
-struct DtypeToCppType;
-
-#define SPECIALIZE_DtypeToCppType(CPP_TYPE, DTYPE) \
-  template <>                                      \
-  struct DtypeToCppType<Dtype::Val::DTYPE> {       \
-    using type = CPP_TYPE;                         \
-  };
-
-MLX_FORALL_DTYPES(SPECIALIZE_DtypeToCppType)
-
-#undef SPECIALIZE_DtypeToCppType
-
-// Maps C++ types to Dtypes.
+// This already exists in C++20 but in C++20 we can also just use templated
+// lambdas which will make this so much nicer.
 template <typename T>
-struct CppTypeToDtype;
+struct type_identity {
+  using type = T;
+};
 
-#define SPECIALIZE_CppTypeToDtype(CPP_TYPE, DTYPE) \
-  template <>                                      \
-  struct CppTypeToDtype<CPP_TYPE>                  \
-      : std::integral_constant<Dtype::Val, Dtype::Val::DTYPE> {};
-
-MLX_FORALL_DTYPES(SPECIALIZE_CppTypeToDtype)
-
-#undef SPECIALIZE_CppTypeToDtype
-
-// Helper macros for switch case macros (see below)
-//
-// These macros are not meant to be used directly. They provide an easy way to
-// generate a switch statement that can handle subsets of Dtypes supported.
-#define MLX_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...)         \
-  case enum_type: {                                                   \
-    using CTYPE_ALIAS = ::mlx::core::DtypeToCppType<enum_type>::type; \
-    __VA_ARGS__;                                                      \
-    break;                                                            \
-  }
-
-#define MLX_INTERNAL_SWITCH_CHECKED(TYPE, NAME, ...)                  \
-  switch (TYPE) {                                                     \
-    __VA_ARGS__                                                       \
-    default:                                                          \
-      throw std::invalid_argument(fmt::format(                        \
-          "Unhandled dtype %s for %s", dtype_to_string(TYPE), NAME)); \
-  }
-
-#define MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...)     \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::uint8, CTYPE_ALIAS, __VA_ARGS__)  \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::uint16, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::uint32, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::uint64, CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::int8, CTYPE_ALIAS, __VA_ARGS__)   \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::int16, CTYPE_ALIAS, __VA_ARGS__)  \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::int32, CTYPE_ALIAS, __VA_ARGS__)  \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::int64, CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, ...)     \
-  MLX_INTERNAL_SWITCH_CASE(                                        \
-      ::mlx::core::Dtype::Val::float16, CTYPE_ALIAS, __VA_ARGS__)  \
-  MLX_INTERNAL_SWITCH_CASE(                                        \
-      ::mlx::core::Dtype::Val::float32, CTYPE_ALIAS, __VA_ARGS__)  \
-  MLX_INTERNAL_SWITCH_CASE(                                        \
-      ::mlx::core::Dtype::Val::float64, CTYPE_ALIAS, __VA_ARGS__)  \
-  MLX_INTERNAL_SWITCH_CASE(                                        \
-      ::mlx::core::Dtype::Val::bfloat16, CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__)     \
-  MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...)        \
-  MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE(                                          \
-      ::mlx::core::Dtype::Val::bool_, CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CASE(                                      \
-      ::mlx::core::Dtype::Val::complex64, CTYPE_ALIAS, __VA_ARGS__)
-
-#define MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...)    \
-  MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__) \
-  MLX_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)
-
-// Switch case macros
-//
-// These macros provide an easy way to generate switch statements that apply a
-// common lambda function to subsets of Dtypes supported by MLX.
-// The lambda function can type specialize to the ctype associated with the
-// Dtype being handled through an alias passed as the CTYPE_ALIAS argument.
-//
-// Arguments:
-// - ADDITIONAL: Additional Dtype case to add
-// - TYPE: The Dtype to handle through the switch statement
-// - NAME: A name for this operation which will be used in error messages
-// - CTYPE_ALIAS: A typedef for the ctype associated with the Dtype.
-// - ...: A statement to be applied to each Dtype case
-//
-// An example usage is:
-//
-// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE, {
-//   output.data<CTYPE>[0] = input.data<CTYPE>[0];
-// });
-//
-// Note that these can be nested as well:
-//
-// MLX_SWITCH_ALL_TYPES(input.dtype(), CTYPE_IN, {
-//   MLX_SWITCH_ALL_TYPES(output.dtype(), CTYPE_OUT, {
-//     output.data<CTYPE_OUT>[0] = input.data<CTYPE_IN>[0];
-//   });
-// });
-//
-// These macros are adapted from Dispatch.h in the ATen library. The primary
-// difference is that the CTYPE_ALIAS argument is exposed to users, which is
-// used to alias the ctype associated with the Dtype that is being handled.
-
-#define MLX_SWITCH_ALL_TYPES(TYPE, CTYPE_ALIAS, ...) \
-  switch (TYPE) { MLX_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__) }
-
-#define MLX_SWITCH_INT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED(                                     \
-      TYPE,                                                        \
-      NAME,                                                        \
-      MLX_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
-
-#define MLX_SWITCH_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED(                                       \
-      TYPE,                                                          \
-      NAME,                                                          \
-      MLX_INTERNAL_SWITCH_CASE_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
-
-#define MLX_SWITCH_INT_FLOAT_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED(                                           \
-      TYPE,                                                              \
-      NAME,                                                              \
-      MLX_INTERNAL_SWITCH_CASE_INT_FLOAT_TYPES(CTYPE_ALIAS, __VA_ARGS__))
-
-#define MLX_SWITCH_REAL_TYPES_CHECKED(TYPE, NAME, CTYPE_ALIAS, ...) \
-  MLX_INTERNAL_SWITCH_CHECKED(                                      \
-      TYPE,                                                         \
-      NAME,                                                         \
-      MLX_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
+#define MLX_GET_TYPE(x) typename decltype(x)::type
+#define MLX_GET_VALUE(x) decltype(x)::value
+
+template <typename F>
+void dispatch_all_types(Dtype dt, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
+  }
+}
+
+template <typename F>
+void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only integer types supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
+
+template <typename F>
+void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only float types supported but " << dt << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
+
+template <typename F>
+void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only integer and float types supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
+
+template <typename F>
+void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
+  switch (dt) {
+    MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
+    MLX_INTERNAL_DTYPE_SWITCH_INTS();
+    MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
+    default:
+      std::ostringstream msg;
+      msg << tag << " Only real numbers supported but " << dt
+          << " was provided";
+      throw std::invalid_argument(msg.str());
+  }
+}
 
 } // namespace mlx::core
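Note: the header's own comment points at the C++20 alternative: with templated lambdas the `type_identity` tag becomes unnecessary, because the dispatcher can name the type directly. A self-contained, hypothetical sketch of that C++20 counterpart (not part of this diff):

    #include <cstdio>

    // Hypothetical C++20 counterpart of the dispatchers above: instead of
    // passing a type_identity tag, it invokes a templated lambda directly.
    template <typename F>
    void dispatch_float_like(bool use_double, F&& f) {
      if (use_double) {
        f.template operator()<double>();
      } else {
        f.template operator()<float>();
      }
    }

    int main() {
      dispatch_float_like(true, [&]<typename T>() {
        std::printf("sizeof(T) = %zu\n", sizeof(T));
      });
    }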
@@ -253,7 +253,9 @@ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) {
 
 std::ostream& operator<<(std::ostream& os, array a) {
   a.eval();
-  MLX_SWITCH_ALL_TYPES(a.dtype(), CTYPE, print_array<CTYPE>(os, a));
+  dispatch_all_types(a.dtype(), [&](auto type_tag) {
+    print_array<MLX_GET_TYPE(type_tag)>(os, a);
+  });
   return os;
 }
 
@@ -321,8 +323,9 @@ void set_iinfo_limits(int64_t& min, uint64_t& max) {
 }
 
 iinfo::iinfo(Dtype dtype) : dtype(dtype) {
-  MLX_SWITCH_INT_TYPES_CHECKED(
-      dtype, "[iinfo]", CTYPE, set_iinfo_limits<CTYPE>(min, max));
+  dispatch_int_types(dtype, "[iinfo]", [&](auto type_tag) {
+    set_iinfo_limits<MLX_GET_TYPE(type_tag)>(min, max);
+  });
 }
 
 } // namespace mlx::core
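Note: `dispatch_int_types` carries the `"[iinfo]"` tag into its error path, so a non-integer dtype now throws `std::invalid_argument` with a message along the lines of `[iinfo] Only integer types supported but float32 was provided`, replacing the fmt-formatted error of MLX_SWITCH_INT_TYPES_CHECKED. Usage stays a one-liner:

    iinfo info(int32);     // ok: the lambda runs with type_identity<int32_t>
    // iinfo bad(float32); // throws std::invalid_argument via the default case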