From e9fbdd20fb8c04363910e0a41c3999a005067a49 Mon Sep 17 00:00:00 2001
From: Anastasiia Filippova
Date: Mon, 16 Jun 2025 18:35:49 +0200
Subject: [PATCH] Helper function to parse types

---
 mlx/distributed/nccl/nccl.cpp | 62 +++++++++++++++++++++++++----------
 1 file changed, 44 insertions(+), 18 deletions(-)

diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp
index 7cb37a05a..8427ecf01 100644
--- a/mlx/distributed/nccl/nccl.cpp
+++ b/mlx/distributed/nccl/nccl.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include
 
 #include "mlx/backend/cuda/device.h"
 #include "mlx/distributed/distributed.h"
@@ -187,30 +188,46 @@ inline void bootstrapUniqueId(
   }
 }
 
-inline ncclDataType_t datatype(const array& arr) {
+template <typename T>
+struct type_identity {
+  using type = T;
+};
+
+template <typename F>
+void dispatch_dtype(const array& arr, F&& f) {
   switch (arr.dtype()) {
     case bool_:
       throw std::invalid_argument("[nccl] Boolean arrays not supported");
     case int8:
-      return ncclChar;
+      f(type_identity<int8_t>{}, ncclChar);
+      break;
     case uint8:
-      return ncclUint8;
+      f(type_identity<uint8_t>{}, ncclUint8);
+      break;
     case int32:
-      return ncclInt;
+      f(type_identity<int32_t>{}, ncclInt);
+      break;
     case uint32:
-      return ncclUint32;
+      f(type_identity<uint32_t>{}, ncclUint32);
+      break;
    case int64:
-      return ncclInt64;
+      f(type_identity<int64_t>{}, ncclInt64);
+      break;
    case uint64:
-      return ncclUint64;
+      f(type_identity<uint64_t>{}, ncclUint64);
+      break;
    case float16:
-      return ncclHalf;
-    case float32:
-      return ncclFloat;
-    case float64:
-      return ncclDouble;
+      f(type_identity<float16_t>{}, ncclHalf);
+      break;
     case bfloat16:
-      return ncclBfloat16;
+      f(type_identity<bfloat16_t>{}, ncclBfloat16);
+      break;
+    case float32:
+      f(type_identity<float>{}, ncclFloat);
+      break;
+    case float64:
+      f(type_identity<double>{}, ncclDouble);
+      break;
     default:
       throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
   }
@@ -259,7 +276,10 @@
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
-    all_reduce_impl(input, output, stream, ncclSum);
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+    });
   }
 
   virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
@@ -290,7 +310,10 @@
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
-    all_reduce_impl(input, output, stream, ncclMax);
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclMax);
+    });
   }
 
   void all_min(const array& input, array& output, Stream stream) override {
@@ -298,7 +321,10 @@
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
-    all_reduce_impl(input, output, stream, ncclMin);
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclMin);
+    });
   }
 
   template <typename T>
@@ -306,9 +332,8 @@ void all_reduce_impl(
       const array& input,
       array& output,
       Stream stream,
+      ncclDataType_t dt,
       ncclRedOp_t op) {
-    ncclDataType_t dt = detail::datatype(input);
-
     CHECK_NCCL(ncclAllReduce(
         input.data<T>(),
         output.data<T>(),
@@ -317,6 +342,7 @@
         op,
         comm_,
         stream_));
+    cudaStreamSynchronize(stream_);
   }
 
   int rank_, size_;
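
Note (not part of the patch): the snippet below is a minimal, self-contained sketch of the tag-dispatch idiom that dispatch_dtype uses, included only to make the lambda-based call sites above easier to follow. The names Dtype, dispatch, and the printf demo are illustrative assumptions; they do not exist in MLX or NCCL.

    // Sketch: pair a runtime dtype value with a compile-time C++ type and hand
    // both to a callback, mirroring dispatch_dtype(arr, f) in the patch.
    #include <cstdint>
    #include <cstdio>

    template <typename T>
    struct type_identity {
      using type = T;
    };

    enum class Dtype { int8, float32 };  // stand-in for arr.dtype()

    template <typename F>
    void dispatch(Dtype dt, F&& f) {
      switch (dt) {
        case Dtype::int8:
          f(type_identity<int8_t>{});
          break;
        case Dtype::float32:
          f(type_identity<float>{});
          break;
      }
    }

    int main() {
      dispatch(Dtype::float32, [](auto type_tag) {
        // Same idiom as the call sites in the patch: recover the static type
        // from the tag passed by the dispatcher.
        using T = typename decltype(type_tag)::type;
        std::printf("element size: %zu bytes\n", sizeof(T));
      });
      return 0;
    }

With this pattern the reduction body is written once and instantiated per element type, which is also why all_reduce_impl now receives the ncclDataType_t as a parameter instead of recomputing it from the array.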