mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Helper function to parse types
This commit is contained in:
parent
f15a127900
commit
e9fbdd20fb
@ -11,6 +11,7 @@
|
|||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
@ -187,30 +188,46 @@ inline void bootstrapUniqueId(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline ncclDataType_t datatype(const array& arr) {
|
template <typename T>
|
||||||
|
struct type_identity {
|
||||||
|
using type = T;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_dtype(const array& arr, F&& f) {
|
||||||
switch (arr.dtype()) {
|
switch (arr.dtype()) {
|
||||||
case bool_:
|
case bool_:
|
||||||
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
||||||
case int8:
|
case int8:
|
||||||
return ncclChar;
|
f(type_identity<int8_t>{}, ncclChar);
|
||||||
|
break;
|
||||||
case uint8:
|
case uint8:
|
||||||
return ncclUint8;
|
f(type_identity<uint8_t>{}, ncclUint8);
|
||||||
|
break;
|
||||||
case int32:
|
case int32:
|
||||||
return ncclInt;
|
f(type_identity<int32_t>{}, ncclInt);
|
||||||
|
break;
|
||||||
case uint32:
|
case uint32:
|
||||||
return ncclUint32;
|
f(type_identity<uint32_t>{}, ncclUint32);
|
||||||
|
break;
|
||||||
case int64:
|
case int64:
|
||||||
return ncclInt64;
|
f(type_identity<int64_t>{}, ncclInt64);
|
||||||
|
break;
|
||||||
case uint64:
|
case uint64:
|
||||||
return ncclUint64;
|
f(type_identity<uint64_t>{}, ncclUint64);
|
||||||
|
break;
|
||||||
case float16:
|
case float16:
|
||||||
return ncclHalf;
|
f(type_identity<float16_t>{}, ncclHalf);
|
||||||
case float32:
|
break;
|
||||||
return ncclFloat;
|
|
||||||
case float64:
|
|
||||||
return ncclDouble;
|
|
||||||
case bfloat16:
|
case bfloat16:
|
||||||
return ncclBfloat16;
|
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
||||||
|
break;
|
||||||
|
case float32:
|
||||||
|
f(type_identity<float>{}, ncclFloat);
|
||||||
|
break;
|
||||||
|
case float64:
|
||||||
|
f(type_identity<double>{}, ncclDouble);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||||
}
|
}
|
||||||
@ -259,7 +276,10 @@ class NCCLGroup : public GroupImpl {
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[nccl] Input and output arrays must have the same size.");
|
"[nccl] Input and output arrays must have the same size.");
|
||||||
}
|
}
|
||||||
all_reduce_impl<float>(input, output, stream, ncclSum);
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclSum);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
@ -290,7 +310,10 @@ class NCCLGroup : public GroupImpl {
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[nccl] Input and output arrays must have the same size.");
|
"[nccl] Input and output arrays must have the same size.");
|
||||||
}
|
}
|
||||||
all_reduce_impl<float>(input, output, stream, ncclMax);
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_min(const array& input, array& output, Stream stream) override {
|
void all_min(const array& input, array& output, Stream stream) override {
|
||||||
@ -298,7 +321,10 @@ class NCCLGroup : public GroupImpl {
|
|||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"[nccl] Input and output arrays must have the same size.");
|
"[nccl] Input and output arrays must have the same size.");
|
||||||
}
|
}
|
||||||
all_reduce_impl<float>(input, output, stream, ncclMin);
|
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
|
||||||
|
using T = typename decltype(type_tag)::type;
|
||||||
|
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -306,9 +332,8 @@ class NCCLGroup : public GroupImpl {
|
|||||||
const array& input,
|
const array& input,
|
||||||
array& output,
|
array& output,
|
||||||
Stream stream,
|
Stream stream,
|
||||||
|
ncclDataType_t dt,
|
||||||
ncclRedOp_t op) {
|
ncclRedOp_t op) {
|
||||||
ncclDataType_t dt = detail::datatype(input);
|
|
||||||
|
|
||||||
CHECK_NCCL(ncclAllReduce(
|
CHECK_NCCL(ncclAllReduce(
|
||||||
input.data<T>(),
|
input.data<T>(),
|
||||||
output.data<T>(),
|
output.data<T>(),
|
||||||
@ -317,6 +342,7 @@ class NCCLGroup : public GroupImpl {
|
|||||||
op,
|
op,
|
||||||
comm_,
|
comm_,
|
||||||
stream_));
|
stream_));
|
||||||
|
cudaStreamSynchronize(stream_);
|
||||||
}
|
}
|
||||||
|
|
||||||
int rank_, size_;
|
int rank_, size_;
|
||||||
|
Loading…
Reference in New Issue
Block a user