Helper function to parse types

Anastasiia Filippova 2025-06-16 18:35:49 +02:00
parent f15a127900
commit e9fbdd20fb


@@ -11,6 +11,7 @@
 #include <mutex>
 #include <stdexcept>
 #include <string>
+#include <type_traits>
 
 #include "mlx/backend/cuda/device.h"
 #include "mlx/distributed/distributed.h"
@@ -187,30 +188,46 @@ inline void bootstrapUniqueId(
   }
 }
 
-inline ncclDataType_t datatype(const array& arr) {
+template <typename T>
+struct type_identity {
+  using type = T;
+};
+
+template <typename F>
+void dispatch_dtype(const array& arr, F&& f) {
   switch (arr.dtype()) {
     case bool_:
       throw std::invalid_argument("[nccl] Boolean arrays not supported");
     case int8:
-      return ncclChar;
+      f(type_identity<int8_t>{}, ncclChar);
+      break;
     case uint8:
-      return ncclUint8;
+      f(type_identity<uint8_t>{}, ncclUint8);
+      break;
     case int32:
-      return ncclInt;
+      f(type_identity<int32_t>{}, ncclInt);
+      break;
     case uint32:
-      return ncclUint32;
+      f(type_identity<uint32_t>{}, ncclUint32);
+      break;
     case int64:
-      return ncclInt64;
+      f(type_identity<int64_t>{}, ncclInt64);
+      break;
     case uint64:
-      return ncclUint64;
+      f(type_identity<uint64_t>{}, ncclUint64);
+      break;
     case float16:
-      return ncclHalf;
-    case float32:
-      return ncclFloat;
-    case float64:
-      return ncclDouble;
+      f(type_identity<float16_t>{}, ncclHalf);
+      break;
     case bfloat16:
-      return ncclBfloat16;
+      f(type_identity<bfloat16_t>{}, ncclBfloat16);
+      break;
+    case float32:
+      f(type_identity<float>{}, ncclFloat);
+      break;
+    case float64:
+      f(type_identity<double>{}, ncclDouble);
+      break;
     default:
       throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
   }
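
Note: dispatch_dtype replaces the old datatype() lookup. Instead of returning only the ncclDataType_t enum, it invokes a caller-supplied callable with a compile-time type tag plus the matching enum, so call sites can instantiate a template for the right element type. A minimal standalone sketch of the same tag-dispatch pattern (the Dtype enum and dispatch() below are simplified stand-ins for illustration, not the mlx types):

    #include <cstdint>
    #include <cstdio>

    enum class Dtype { int32, float32 }; // stand-in for mlx's Dtype

    template <typename T>
    struct type_identity {
      using type = T;
    };

    // Invokes f with a tag whose nested `type` names the element type.
    template <typename F>
    void dispatch(Dtype dt, F&& f) {
      switch (dt) {
        case Dtype::int32:
          f(type_identity<int32_t>{});
          break;
        case Dtype::float32:
          f(type_identity<float>{});
          break;
      }
    }

    int main() {
      dispatch(Dtype::float32, [](auto tag) {
        using T = typename decltype(tag)::type; // T = float here
        std::printf("element size: %zu\n", sizeof(T)); // prints 4
      });
    }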
@@ -259,7 +276,10 @@ class NCCLGroup : public GroupImpl {
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
-    all_reduce_impl<float>(input, output, stream, ncclSum);
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclSum);
+    });
   }
 
   virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
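
Note: the tag is needed because a pre-C++20 generic lambda has no way to name its deduced type parameter; `using T = typename decltype(type_tag)::type;` recovers the element type from the tag's nested typedef. If the codebase moved to C++20, the hand-rolled type_identity could give way to std::type_identity (from the newly included <type_traits>) and a template lambda; a hypothetical equivalent of this call site:

    // Hypothetical C++20 variant, assuming dispatch_dtype passed
    // std::type_identity<T>{} as the tag.
    detail::dispatch_dtype(
        input, [&]<typename T>(std::type_identity<T>, ncclDataType_t dt) {
          all_reduce_impl<T>(input, output, stream, dt, ncclSum);
        });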
@@ -290,7 +310,10 @@ class NCCLGroup : public GroupImpl {
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
-    all_reduce_impl<float>(input, output, stream, ncclMax);
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclMax);
+    });
   }
 
   void all_min(const array& input, array& output, Stream stream) override {
@@ -298,7 +321,10 @@ class NCCLGroup : public GroupImpl {
       throw std::runtime_error(
           "[nccl] Input and output arrays must have the same size.");
     }
-    all_reduce_impl<float>(input, output, stream, ncclMin);
+    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+      using T = typename decltype(type_tag)::type;
+      all_reduce_impl<T>(input, output, stream, dt, ncclMin);
+    });
   }
 
   template <typename T>
@@ -306,9 +332,8 @@ class NCCLGroup : public GroupImpl {
       const array& input,
       array& output,
       Stream stream,
+      ncclDataType_t dt,
       ncclRedOp_t op) {
-    ncclDataType_t dt = detail::datatype(input);
     CHECK_NCCL(ncclAllReduce(
         input.data<T>(),
         output.data<T>(),
@@ -317,6 +342,7 @@ class NCCLGroup : public GroupImpl {
         op,
         comm_,
         stream_));
+    cudaStreamSynchronize(stream_);
   }
 
   int rank_, size_;
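
Note: with the datatype threaded through from the dispatcher, all_reduce_impl no longer recomputes it, and the added cudaStreamSynchronize makes the collective blocking: ncclAllReduce only enqueues work on the stream, so the host now waits for the reduction to finish before returning (the cudaError_t returned by the synchronize call is ignored here). CHECK_NCCL itself is not shown in this diff; a typical NCCL error-checking macro looks roughly like the sketch below, though mlx's actual definition may differ:

    #include <nccl.h>
    #include <stdexcept>
    #include <string>

    // Sketch of a NCCL error-check macro; not mlx's actual CHECK_NCCL.
    #define CHECK_NCCL(cmd)                                       \
      do {                                                        \
        ncclResult_t res_ = (cmd);                                \
        if (res_ != ncclSuccess) {                                \
          throw std::runtime_error(                               \
              std::string("[nccl] ") + ncclGetErrorString(res_)); \
        }                                                         \
      } while (0)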