mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
MLX_SWITCH macros to templates (#2320)
This commit is contained in:
committed by
GitHub
parent
33bf1a244b
commit
3d5e17e507
@@ -6,6 +6,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/backend/cuda/device/utils.cuh"
|
||||
|
||||
@@ -17,60 +19,46 @@
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Convert a runtime number in the range 1..3 to a compile-time constant.
|
||||
#define MLX_SWITCH_1_2_3(N, NDIM, ...) \
|
||||
switch (N) { \
|
||||
case 1: { \
|
||||
constexpr int NDIM = 1; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 2: { \
|
||||
constexpr int NDIM = 2; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case 3: { \
|
||||
constexpr int NDIM = 3; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
// Dispatches a runtime integer in the range 1..3 to a compile-time constant.
//
// `f` is invoked exactly once with a `std::integral_constant<int, N>` whose
// value equals `n`, so the callee can recover a constexpr value through
// `decltype(arg)::value` or `arg()`.
//
// If `n` is outside 1..3, `f` is not invoked and the function returns
// silently; callers are expected to validate `n` beforehand.
template <typename F>
void dispatch_1_2_3(int n, F&& f) {
  switch (n) {
    case 1:
      f(std::integral_constant<int, 1>{});
      break;
    case 2:
      f(std::integral_constant<int, 2>{});
      break;
    case 3:
      f(std::integral_constant<int, 3>{});
      break;
    default:
      // Intentionally a no-op for unsupported values.
      break;
  }
}
|
||||
|
||||
// Like MLX_SWITCH_ALL_TYPES but for booleans.
|
||||
#define MLX_SWITCH_BOOL(BOOL, BOOL_ALIAS, ...) \
|
||||
if (BOOL) { \
|
||||
constexpr bool BOOL_ALIAS = true; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr bool BOOL_ALIAS = false; \
|
||||
__VA_ARGS__; \
|
||||
// Dispatches a runtime boolean to a compile-time constant.
//
// `f` is invoked exactly once: with `std::true_type{}` when `v` is true and
// with `std::false_type{}` otherwise. The callee can recover a constexpr bool
// through `decltype(arg)::value` or `arg()`.
template <typename F>
void dispatch_bool(bool v, F&& f) {
  if (v) {
    f(std::true_type{});
  } else {
    f(std::false_type{});
  }
}
|
||||
|
||||
// Convert a runtime block_dim to a compile-time constant between WARP_SIZE
// and WARP_SIZE squared (rounded up to the next supported size).
|
||||
#define MLX_SWITCH_BLOCK_DIM(NUM_THREADS, BLOCK_DIM, ...) \
|
||||
{ \
|
||||
uint32_t _num_threads = NUM_THREADS; \
|
||||
if (_num_threads <= WARP_SIZE) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 2) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 2; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 4) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 4; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 8) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 8; \
|
||||
__VA_ARGS__; \
|
||||
} else if (_num_threads <= WARP_SIZE * 16) { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * 16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
constexpr uint32_t BLOCK_DIM = WARP_SIZE * WARP_SIZE; \
|
||||
__VA_ARGS__; \
|
||||
} \
|
||||
template <typename F>
|
||||
void dispatch_block_dim(int threads, F&& f) {
|
||||
if (threads <= WARP_SIZE) {
|
||||
f(std::integral_constant<int, WARP_SIZE>{});
|
||||
} else if (threads <= WARP_SIZE * 2) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 2>{});
|
||||
} else if (threads <= WARP_SIZE * 4) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 4>{});
|
||||
} else if (threads <= WARP_SIZE * 8) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 8>{});
|
||||
} else if (threads <= WARP_SIZE * 16) {
|
||||
f(std::integral_constant<int, WARP_SIZE * 16>{});
|
||||
} else {
|
||||
f(std::integral_constant<int, WARP_SIZE * 32>{});
|
||||
}
|
||||
}
|
||||
|
||||
// Maps CPU types to CUDA types.
|
||||
template <typename T>
|
||||
|
||||
Reference in New Issue
Block a user