Helper function to parse types

This commit is contained in:
Anastasiia Filippova 2025-06-16 18:35:49 +02:00
parent f15a127900
commit e9fbdd20fb

View File

@@ -11,6 +11,7 @@
#include <mutex> #include <mutex>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <type_traits>
#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/device.h"
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
@@ -187,30 +188,46 @@ inline void bootstrapUniqueId(
} }
} }
inline ncclDataType_t datatype(const array& arr) { template <typename T>
struct type_identity {
using type = T;
};
template <typename F>
void dispatch_dtype(const array& arr, F&& f) {
switch (arr.dtype()) { switch (arr.dtype()) {
case bool_: case bool_:
throw std::invalid_argument("[nccl] Boolean arrays not supported"); throw std::invalid_argument("[nccl] Boolean arrays not supported");
case int8: case int8:
return ncclChar; f(type_identity<int8_t>{}, ncclChar);
break;
case uint8: case uint8:
return ncclUint8; f(type_identity<uint8_t>{}, ncclUint8);
break;
case int32: case int32:
return ncclInt; f(type_identity<int32_t>{}, ncclInt);
break;
case uint32: case uint32:
return ncclUint32; f(type_identity<uint32_t>{}, ncclUint32);
break;
case int64: case int64:
return ncclInt64; f(type_identity<int64_t>{}, ncclInt64);
break;
case uint64: case uint64:
return ncclUint64; f(type_identity<uint64_t>{}, ncclUint64);
break;
case float16: case float16:
return ncclHalf; f(type_identity<float16_t>{}, ncclHalf);
case float32: break;
return ncclFloat;
case float64:
return ncclDouble;
case bfloat16: case bfloat16:
return ncclBfloat16; f(type_identity<bfloat16_t>{}, ncclBfloat16);
break;
case float32:
f(type_identity<float>{}, ncclFloat);
break;
case float64:
f(type_identity<double>{}, ncclDouble);
break;
default: default:
throw std::invalid_argument("[nccl] Unknown or unsupported dtype"); throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
} }
@@ -259,7 +276,10 @@ class NCCLGroup : public GroupImpl {
throw std::runtime_error( throw std::runtime_error(
"[nccl] Input and output arrays must have the same size."); "[nccl] Input and output arrays must have the same size.");
} }
all_reduce_impl<float>(input, output, stream, ncclSum); detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
});
} }
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override { virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
@@ -290,7 +310,10 @@ class NCCLGroup : public GroupImpl {
throw std::runtime_error( throw std::runtime_error(
"[nccl] Input and output arrays must have the same size."); "[nccl] Input and output arrays must have the same size.");
} }
all_reduce_impl<float>(input, output, stream, ncclMax); detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
});
} }
void all_min(const array& input, array& output, Stream stream) override { void all_min(const array& input, array& output, Stream stream) override {
@@ -298,7 +321,10 @@ class NCCLGroup : public GroupImpl {
throw std::runtime_error( throw std::runtime_error(
"[nccl] Input and output arrays must have the same size."); "[nccl] Input and output arrays must have the same size.");
} }
all_reduce_impl<float>(input, output, stream, ncclMin); detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
});
} }
template <typename T> template <typename T>
@@ -306,9 +332,8 @@ class NCCLGroup : public GroupImpl {
const array& input, const array& input,
array& output, array& output,
Stream stream, Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) { ncclRedOp_t op) {
ncclDataType_t dt = detail::datatype(input);
CHECK_NCCL(ncclAllReduce( CHECK_NCCL(ncclAllReduce(
input.data<T>(), input.data<T>(),
output.data<T>(), output.data<T>(),
@@ -317,6 +342,7 @@ class NCCLGroup : public GroupImpl {
op, op,
comm_, comm_,
stream_)); stream_));
cudaStreamSynchronize(stream_);
} }
int rank_, size_; int rank_, size_;