mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
MLX_SWITCH macros to templates (#2320)
This commit is contained in:
committed by
GitHub
parent
33bf1a244b
commit
3d5e17e507
@@ -1,5 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/backend/common/reduce.h"
|
||||
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
@@ -9,43 +11,35 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Dispatch dynamic ndim to constexpr.
|
||||
// The behavior follows get_kernel_reduce_ndim in metal/reduce.cpp file.
|
||||
#define MLX_SWITCH_REDUCE_NDIM(ndim, NDIM, ...) \
|
||||
if (ndim == 1) { \
|
||||
constexpr uint32_t NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
} else if (ndim == 2) { \
|
||||
constexpr uint32_t NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t NDIM = 5; \
|
||||
__VA_ARGS__; \
|
||||
template <typename F>
|
||||
void dispatch_reduce_ndim(int ndim, F&& f) {
|
||||
if (ndim == 1) {
|
||||
f(std::integral_constant<int, 1>{});
|
||||
} else if (ndim == 2) {
|
||||
f(std::integral_constant<int, 2>{});
|
||||
} else {
|
||||
f(std::integral_constant<int, 5>{});
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch reduce ops to constexpr.
|
||||
#define MLX_SWITCH_REDUCE_OPS(REDUCE, OP, ...) \
|
||||
if (REDUCE == Reduce::ReduceType::And) { \
|
||||
using OP = cu::And; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Or) { \
|
||||
using OP = cu::Or; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Sum) { \
|
||||
using OP = cu::Sum; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Prod) { \
|
||||
using OP = cu::Prod; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Max) { \
|
||||
using OP = cu::Max; \
|
||||
__VA_ARGS__; \
|
||||
} else if (REDUCE == Reduce::ReduceType::Min) { \
|
||||
using OP = cu::Min; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
throw std::invalid_argument("Unknown reduce type."); \
|
||||
template <typename F>
|
||||
void dispatch_reduce_ops(Reduce::ReduceType reduce_type, F&& f) {
|
||||
if (reduce_type == Reduce::ReduceType::And) {
|
||||
f(type_identity<cu::And>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Or) {
|
||||
f(type_identity<cu::Or>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Sum) {
|
||||
f(type_identity<cu::Sum>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Prod) {
|
||||
f(type_identity<cu::Prod>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Max) {
|
||||
f(type_identity<cu::Max>{});
|
||||
} else if (reduce_type == Reduce::ReduceType::Min) {
|
||||
f(type_identity<cu::Min>{});
|
||||
} else {
|
||||
throw std::invalid_argument("Unknown reduce type.");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce(
|
||||
cu::CommandEncoder& encoder,
|
||||
|
||||
Reference in New Issue
Block a user