From cd53eb1ae30ada5d17e487b29e1d421d51298942 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 20 Aug 2025 15:09:41 +0200 Subject: [PATCH] dispatch types with dtype_utils --- mlx/distributed/nccl/nccl.cpp | 85 +++++++++++++++++------------------ mlx/distributed/ring/ring.cpp | 1 + 2 files changed, 41 insertions(+), 45 deletions(-) diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index 7d8358fb3..23176c81b 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -16,6 +16,7 @@ #include "mlx/backend/cuda/device.h" #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" +#include "mlx/dtype_utils.h" namespace mlx::core::distributed::nccl { @@ -47,8 +48,47 @@ namespace mlx::core::distributed::nccl { } \ } while (0) +#define MLX_NCCL_TYPE_LIST(X) \ + X(int8_t, ncclChar) \ + X(uint8_t, ncclUint8) \ + X(int32_t, ncclInt) \ + X(uint32_t, ncclUint32) \ + X(int64_t, ncclInt64) \ + X(uint64_t, ncclUint64) \ + X(float16_t, ncclHalf) \ + X(bfloat16_t, ncclBfloat16) \ + X(float, ncclFloat) \ + X(double, ncclDouble) + +template +struct nccl_map { + static constexpr bool ok = false; // default: unsupported +}; + +#define MLX_DEF_NCCL_MAP(T, E) \ + template <> \ + struct nccl_map { \ + static constexpr bool ok = true; \ + static constexpr ncclDataType_t value = E; \ + }; + +MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP) +#undef MLX_DEF_NCCL_MAP + namespace detail { +template +void dispatch_dtype(const array& arr, F&& f) { + dispatch_all_types(arr.dtype(), [&](auto type_tag) { + using T = MLX_GET_TYPE(type_tag); + if constexpr (nccl_map::ok) { + f(type_tag, nccl_map::value); + } else { + throw std::invalid_argument("[nccl] Unknown or unsupported dtype"); + } + }); +} + inline void sendAll(int sock, const void* buf, size_t len) { const char* ptr = reinterpret_cast(buf); while (len > 0) { @@ -189,51 +229,6 @@ inline void bootstrap_unique_id( } } -template -struct type_identity { - using type = T; -}; - -template -void dispatch_dtype(const array& arr, F&& f) { - switch (arr.dtype()) { - case bool_: - throw std::invalid_argument("[nccl] Boolean arrays not supported"); - case int8: - f(type_identity{}, ncclChar); - break; - case uint8: - f(type_identity{}, ncclUint8); - break; - case int32: - f(type_identity{}, ncclInt); - break; - case uint32: - f(type_identity{}, ncclUint32); - break; - case int64: - f(type_identity{}, ncclInt64); - break; - case uint64: - f(type_identity{}, ncclUint64); - break; - case float16: - f(type_identity{}, ncclHalf); - break; - case bfloat16: - f(type_identity{}, ncclBfloat16); - break; - case float32: - f(type_identity{}, ncclFloat); - break; - case float64: - f(type_identity{}, ncclDouble); - break; - default: - throw std::invalid_argument("[nccl] Unknown or unsupported dtype"); - } -} - } // namespace detail using GroupImpl = mlx::core::distributed::detail::GroupImpl; diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index ce0967d53..b31274e23 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -975,6 +975,7 @@ class RingGroup : public GroupImpl { int rank_; int size_; + bool verbose_; ThreadPool pool_;