mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
dispatch types with dtype_utils
This commit is contained in:
parent
f7c11b965e
commit
cd53eb1ae3
@ -16,6 +16,7 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
|
||||
namespace mlx::core::distributed::nccl {
|
||||
|
||||
@ -47,8 +48,47 @@ namespace mlx::core::distributed::nccl {
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define MLX_NCCL_TYPE_LIST(X) \
|
||||
X(int8_t, ncclChar) \
|
||||
X(uint8_t, ncclUint8) \
|
||||
X(int32_t, ncclInt) \
|
||||
X(uint32_t, ncclUint32) \
|
||||
X(int64_t, ncclInt64) \
|
||||
X(uint64_t, ncclUint64) \
|
||||
X(float16_t, ncclHalf) \
|
||||
X(bfloat16_t, ncclBfloat16) \
|
||||
X(float, ncclFloat) \
|
||||
X(double, ncclDouble)
|
||||
|
||||
template <class>
|
||||
struct nccl_map {
|
||||
static constexpr bool ok = false; // default: unsupported
|
||||
};
|
||||
|
||||
#define MLX_DEF_NCCL_MAP(T, E) \
|
||||
template <> \
|
||||
struct nccl_map<T> { \
|
||||
static constexpr bool ok = true; \
|
||||
static constexpr ncclDataType_t value = E; \
|
||||
};
|
||||
|
||||
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
|
||||
#undef MLX_DEF_NCCL_MAP
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F>
|
||||
void dispatch_dtype(const array& arr, F&& f) {
|
||||
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
|
||||
using T = MLX_GET_TYPE(type_tag);
|
||||
if constexpr (nccl_map<T>::ok) {
|
||||
f(type_tag, nccl_map<T>::value);
|
||||
} else {
|
||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||
while (len > 0) {
|
||||
@ -189,51 +229,6 @@ inline void bootstrap_unique_id(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct type_identity {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename F>
|
||||
void dispatch_dtype(const array& arr, F&& f) {
|
||||
switch (arr.dtype()) {
|
||||
case bool_:
|
||||
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
||||
case int8:
|
||||
f(type_identity<int8_t>{}, ncclChar);
|
||||
break;
|
||||
case uint8:
|
||||
f(type_identity<uint8_t>{}, ncclUint8);
|
||||
break;
|
||||
case int32:
|
||||
f(type_identity<int32_t>{}, ncclInt);
|
||||
break;
|
||||
case uint32:
|
||||
f(type_identity<uint32_t>{}, ncclUint32);
|
||||
break;
|
||||
case int64:
|
||||
f(type_identity<int64_t>{}, ncclInt64);
|
||||
break;
|
||||
case uint64:
|
||||
f(type_identity<uint64_t>{}, ncclUint64);
|
||||
break;
|
||||
case float16:
|
||||
f(type_identity<float16_t>{}, ncclHalf);
|
||||
break;
|
||||
case bfloat16:
|
||||
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
||||
break;
|
||||
case float32:
|
||||
f(type_identity<float>{}, ncclFloat);
|
||||
break;
|
||||
case float64:
|
||||
f(type_identity<double>{}, ncclDouble);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||
|
@ -975,6 +975,7 @@ class RingGroup : public GroupImpl {
|
||||
|
||||
int rank_;
|
||||
int size_;
|
||||
|
||||
bool verbose_;
|
||||
|
||||
ThreadPool pool_;
|
||||
|
Loading…
Reference in New Issue
Block a user