mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Helper function to parse types
This commit is contained in:
parent
f15a127900
commit
e9fbdd20fb
@ -11,6 +11,7 @@
|
||||
#include <mutex>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
@ -187,30 +188,46 @@ inline void bootstrapUniqueId(
|
||||
}
|
||||
}
|
||||
|
||||
inline ncclDataType_t datatype(const array& arr) {
|
||||
template <typename T>
|
||||
struct type_identity {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
void dispatch_dtype(const array& arr, F&& f) {
|
||||
switch (arr.dtype()) {
|
||||
case bool_:
|
||||
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
||||
case int8:
|
||||
return ncclChar;
|
||||
f(type_identity<int8_t>{}, ncclChar);
|
||||
break;
|
||||
case uint8:
|
||||
return ncclUint8;
|
||||
f(type_identity<uint8_t>{}, ncclUint8);
|
||||
break;
|
||||
case int32:
|
||||
return ncclInt;
|
||||
f(type_identity<int32_t>{}, ncclInt);
|
||||
break;
|
||||
case uint32:
|
||||
return ncclUint32;
|
||||
f(type_identity<uint32_t>{}, ncclUint32);
|
||||
break;
|
||||
case int64:
|
||||
return ncclInt64;
|
||||
f(type_identity<int64_t>{}, ncclInt64);
|
||||
break;
|
||||
case uint64:
|
||||
return ncclUint64;
|
||||
f(type_identity<uint64_t>{}, ncclUint64);
|
||||
break;
|
||||
case float16:
|
||||
return ncclHalf;
|
||||
case float32:
|
||||
return ncclFloat;
|
||||
case float64:
|
||||
return ncclDouble;
|
||||
f(type_identity<float16_t>{}, ncclHalf);
|
||||
break;
|
||||
case bfloat16:
|
||||
return ncclBfloat16;
|
||||
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
||||
break;
|
||||
case float32:
|
||||
f(type_identity<float>{}, ncclFloat);
|
||||
break;
|
||||
case float64:
|
||||
f(type_identity<double>{}, ncclDouble);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||
}
|
||||
@ -259,7 +276,10 @@ class NCCLGroup : public GroupImpl {
|
||||
throw std::runtime_error(
|
||||
"[nccl] Input and output arrays must have the same size.");
|
||||
}
|
||||
all_reduce_impl<float>(input, output, stream, ncclSum);
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||
});
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||
@ -290,7 +310,10 @@ class NCCLGroup : public GroupImpl {
|
||||
throw std::runtime_error(
|
||||
"[nccl] Input and output arrays must have the same size.");
|
||||
}
|
||||
all_reduce_impl<float>(input, output, stream, ncclMax);
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
|
||||
});
|
||||
}
|
||||
|
||||
void all_min(const array& input, array& output, Stream stream) override {
|
||||
@ -298,7 +321,10 @@ class NCCLGroup : public GroupImpl {
|
||||
throw std::runtime_error(
|
||||
"[nccl] Input and output arrays must have the same size.");
|
||||
}
|
||||
all_reduce_impl<float>(input, output, stream, ncclMin);
|
||||
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||
using T = typename decltype(type_tag)::type;
|
||||
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -306,9 +332,8 @@ class NCCLGroup : public GroupImpl {
|
||||
const array& input,
|
||||
array& output,
|
||||
Stream stream,
|
||||
ncclDataType_t dt,
|
||||
ncclRedOp_t op) {
|
||||
ncclDataType_t dt = detail::datatype(input);
|
||||
|
||||
CHECK_NCCL(ncclAllReduce(
|
||||
input.data<T>(),
|
||||
output.data<T>(),
|
||||
@ -317,6 +342,7 @@ class NCCLGroup : public GroupImpl {
|
||||
op,
|
||||
comm_,
|
||||
stream_));
|
||||
cudaStreamSynchronize(stream_);
|
||||
}
|
||||
|
||||
int rank_, size_;
|
||||
|
Loading…
Reference in New Issue
Block a user