MLX_SWITCH macros to templates (#2320)

This commit is contained in:
Angelos Katharopoulos
2025-07-01 01:33:44 -07:00
committed by GitHub
parent 33bf1a244b
commit 3d5e17e507
27 changed files with 693 additions and 692 deletions

View File

@@ -6,6 +6,8 @@
#pragma once
#include <type_traits>
#include "mlx/array.h"
#include "mlx/backend/cuda/device/utils.cuh"
@@ -17,60 +19,46 @@
namespace mlx::core {
// Convert a number between 1~3 to constexpr.
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
switch (N) { \
case 1: { \
constexpr int NDIM = 1; \
__VA_ARGS__; \
break; \
} \
case 2: { \
constexpr int NDIM = 2; \
__VA_ARGS__; \
break; \
} \
case 3: { \
constexpr int NDIM = 3; \
__VA_ARGS__; \
break; \
} \
template <typename F>
void dispatch_1_2_3(int n, F&& f) {
switch (n) {
case 1:
f(std::integral_constant<int, 1>{});
break;
case 2:
f(std::integral_constant<int, 2>{});
break;
case 3:
f(std::integral_constant<int, 3>{});
break;
}
}
// Like MLX_SWITCH_ALL_TYPES but for booleans.
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
if (BOOL) { \
constexpr bool BOOL_ALIAS = true; \
__VA_ARGS__; \
} else { \
constexpr bool BOOL_ALIAS = false; \
__VA_ARGS__; \
template <typename F>
void dispatch_bool(bool v, F&& f) {
if (v) {
f(std::true_type{});
} else {
f(std::false_type{});
}
}
// Convert a block_dim to constexpr between WARP_SIZE and WARP_SIZE ^ 2.
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
{ \
uint32_t _num_threads = NUM_THREADS; \
if (_num_threads <= WARP_SIZE) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 2) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 4) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 8) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
__VA_ARGS__; \
} else if (_num_threads <= WARP_SIZE * 16) { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
__VA_ARGS__; \
} else { \
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
__VA_ARGS__; \
} \
template <typename F>
void dispatch_block_dim(int threads, F&& f) {
if (threads <= WARP_SIZE) {
f(std::integral_constant<int, WARP_SIZE>{});
} else if (threads <= WARP_SIZE * 2) {
f(std::integral_constant<int, WARP_SIZE * 2>{});
} else if (threads <= WARP_SIZE * 4) {
f(std::integral_constant<int, WARP_SIZE * 4>{});
} else if (threads <= WARP_SIZE * 8) {
f(std::integral_constant<int, WARP_SIZE * 8>{});
} else if (threads <= WARP_SIZE * 16) {
f(std::integral_constant<int, WARP_SIZE * 16>{});
} else {
f(std::integral_constant<int, WARP_SIZE * 32>{});
}
}
// Maps CPU types to CUDA types.
template <typename T>