mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 14:59:22 +08:00
dispatch types with dtype_utils
This commit is contained in:
parent
f7c11b965e
commit
cd53eb1ae3
@ -16,6 +16,7 @@
|
|||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/dtype_utils.h"
|
||||||
|
|
||||||
namespace mlx::core::distributed::nccl {
|
namespace mlx::core::distributed::nccl {
|
||||||
|
|
||||||
@ -47,8 +48,47 @@ namespace mlx::core::distributed::nccl {
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
#define MLX_NCCL_TYPE_LIST(X) \
|
||||||
|
X(int8_t, ncclChar) \
|
||||||
|
X(uint8_t, ncclUint8) \
|
||||||
|
X(int32_t, ncclInt) \
|
||||||
|
X(uint32_t, ncclUint32) \
|
||||||
|
X(int64_t, ncclInt64) \
|
||||||
|
X(uint64_t, ncclUint64) \
|
||||||
|
X(float16_t, ncclHalf) \
|
||||||
|
X(bfloat16_t, ncclBfloat16) \
|
||||||
|
X(float, ncclFloat) \
|
||||||
|
X(double, ncclDouble)
|
||||||
|
|
||||||
|
template <class>
|
||||||
|
struct nccl_map {
|
||||||
|
static constexpr bool ok = false; // default: unsupported
|
||||||
|
};
|
||||||
|
|
||||||
|
#define MLX_DEF_NCCL_MAP(T, E) \
|
||||||
|
template <> \
|
||||||
|
struct nccl_map<T> { \
|
||||||
|
static constexpr bool ok = true; \
|
||||||
|
static constexpr ncclDataType_t value = E; \
|
||||||
|
};
|
||||||
|
|
||||||
|
MLX_NCCL_TYPE_LIST(MLX_DEF_NCCL_MAP)
|
||||||
|
#undef MLX_DEF_NCCL_MAP
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void dispatch_dtype(const array& arr, F&& f) {
|
||||||
|
dispatch_all_types(arr.dtype(), [&](auto type_tag) {
|
||||||
|
using T = MLX_GET_TYPE(type_tag);
|
||||||
|
if constexpr (nccl_map<T>::ok) {
|
||||||
|
f(type_tag, nccl_map<T>::value);
|
||||||
|
} else {
|
||||||
|
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
inline void sendAll(int sock, const void* buf, size_t len) {
|
inline void sendAll(int sock, const void* buf, size_t len) {
|
||||||
const char* ptr = reinterpret_cast<const char*>(buf);
|
const char* ptr = reinterpret_cast<const char*>(buf);
|
||||||
while (len > 0) {
|
while (len > 0) {
|
||||||
@ -189,51 +229,6 @@ inline void bootstrap_unique_id(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct type_identity {
|
|
||||||
using type = T;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename F>
|
|
||||||
void dispatch_dtype(const array& arr, F&& f) {
|
|
||||||
switch (arr.dtype()) {
|
|
||||||
case bool_:
|
|
||||||
throw std::invalid_argument("[nccl] Boolean arrays not supported");
|
|
||||||
case int8:
|
|
||||||
f(type_identity<int8_t>{}, ncclChar);
|
|
||||||
break;
|
|
||||||
case uint8:
|
|
||||||
f(type_identity<uint8_t>{}, ncclUint8);
|
|
||||||
break;
|
|
||||||
case int32:
|
|
||||||
f(type_identity<int32_t>{}, ncclInt);
|
|
||||||
break;
|
|
||||||
case uint32:
|
|
||||||
f(type_identity<uint32_t>{}, ncclUint32);
|
|
||||||
break;
|
|
||||||
case int64:
|
|
||||||
f(type_identity<int64_t>{}, ncclInt64);
|
|
||||||
break;
|
|
||||||
case uint64:
|
|
||||||
f(type_identity<uint64_t>{}, ncclUint64);
|
|
||||||
break;
|
|
||||||
case float16:
|
|
||||||
f(type_identity<float16_t>{}, ncclHalf);
|
|
||||||
break;
|
|
||||||
case bfloat16:
|
|
||||||
f(type_identity<bfloat16_t>{}, ncclBfloat16);
|
|
||||||
break;
|
|
||||||
case float32:
|
|
||||||
f(type_identity<float>{}, ncclFloat);
|
|
||||||
break;
|
|
||||||
case float64:
|
|
||||||
f(type_identity<double>{}, ncclDouble);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::invalid_argument("[nccl] Unknown or unsupported dtype");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
using GroupImpl = mlx::core::distributed::detail::GroupImpl;
|
||||||
|
@ -975,6 +975,7 @@ class RingGroup : public GroupImpl {
|
|||||||
|
|
||||||
int rank_;
|
int rank_;
|
||||||
int size_;
|
int size_;
|
||||||
|
|
||||||
bool verbose_;
|
bool verbose_;
|
||||||
|
|
||||||
ThreadPool pool_;
|
ThreadPool pool_;
|
||||||
|
Loading…
Reference in New Issue
Block a user