mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Gloo backend support
This commit is contained in:
parent
70ffaa50d2
commit
87b680766e
@ -1,16 +1,20 @@
|
|||||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
|
||||||
|
|
||||||
if (MLX_BUILD_CPU)
|
if(MLX_BUILD_CPU)
|
||||||
if (MLX_CUSTOM_DISTRIBUTED)
|
if(MLX_CUSTOM_DISTRIBUTED)
|
||||||
|
if(MLX_CUSTOM_DISTRIBUTED STREQUAL "gloo")
|
||||||
|
message(STATUS "Distributed: using gloo backend")
|
||||||
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/gloo)
|
||||||
|
else()
|
||||||
|
message(STATUS "Distributed: using sockets backend")
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
|
||||||
elseif (MPI_FOUND)
|
endif()
|
||||||
|
elseif(MPI_FOUND)
|
||||||
|
message(STATUS "Distributed: using MPI backend")
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
|
||||||
else()
|
else()
|
||||||
target_sources(
|
message(STATUS "Distributed: no support")
|
||||||
mlx
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
|
||||||
PRIVATE
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
|
|
||||||
)
|
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
25
mlx/distributed/gloo/CMakeLists.txt
Normal file
25
mlx/distributed/gloo/CMakeLists.txt
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
find_path(
|
||||||
|
GLOO_INCLUDE_DIR gloo/allreduce.h
|
||||||
|
PATHS ${GLOO_INC_DIR}
|
||||||
|
PATH_SUFFIXES include)
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
GLOO_LIBRARY gloo
|
||||||
|
PATHS ${GLOO_LIB_DIR}
|
||||||
|
PATH_SUFFIXES lib
|
||||||
|
HINTS GLOO)
|
||||||
|
|
||||||
|
find_library(
|
||||||
|
UV_LIBRARY uv
|
||||||
|
PATHS ${UV_LIB_DIR}
|
||||||
|
PATH_SUFFIXES lib
|
||||||
|
HINTS UV)
|
||||||
|
|
||||||
|
message(STATUS "GLOO LIB <${GLOO_LIBRARY}>")
|
||||||
|
message(STATUS "GLOO INC <${GLOO_INCLUDE_DIR}>")
|
||||||
|
message(STATUS "UV LIB <${UV_LIB_DIR}>")
|
||||||
|
|
||||||
|
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gloo.cpp)
|
||||||
|
target_link_libraries(mlx PUBLIC ${GLOO_LIBRARY})
|
||||||
|
target_link_libraries(mlx PUBLIC ${UV_LIBRARY})
|
||||||
|
target_include_directories(mlx PRIVATE ${GLOO_INCLUDE_DIR})
|
178
mlx/distributed/gloo/gloo.cpp
Normal file
178
mlx/distributed/gloo/gloo.cpp
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <sstream>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "mlx/backend/common/copy.h"
|
||||||
|
#include "mlx/distributed/distributed.h"
|
||||||
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
|
#include "mlx/io/threadpool.h"
|
||||||
|
|
||||||
|
#include "gloo/allreduce.h"
|
||||||
|
#include "gloo/math.h"
|
||||||
|
#include "gloo/mpi/context.h"
|
||||||
|
#include "gloo/transport/uv/device.h"
|
||||||
|
|
||||||
|
#define SWITCH_TYPE(x, ...) \
|
||||||
|
switch ((x).dtype()) { \
|
||||||
|
case bool_: { \
|
||||||
|
using T = bool; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case int8: { \
|
||||||
|
using T = int8_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case int16: { \
|
||||||
|
using T = int16_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case int32: { \
|
||||||
|
using T = int32_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case int64: { \
|
||||||
|
using T = int64_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case uint8: { \
|
||||||
|
using T = uint8_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case uint16: { \
|
||||||
|
using T = uint16_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case uint32: { \
|
||||||
|
using T = uint32_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case uint64: { \
|
||||||
|
using T = uint64_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case bfloat16: { \
|
||||||
|
using T = bfloat16_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case float16: { \
|
||||||
|
using T = float16_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case float32: { \
|
||||||
|
using T = float; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
case complex64: { \
|
||||||
|
using T = complex64_t; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
} break; \
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mlx::core::distributed {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
array ensure_row_contiguous(const array& arr) {
|
||||||
|
if (arr.flags().row_contiguous) {
|
||||||
|
return arr;
|
||||||
|
} else {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy(arr, arr_copy, CopyType::General);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool is_available() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Group::rank() {
|
||||||
|
return std::static_pointer_cast<gloo::mpi::Context>(group_)->rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
int Group::size() {
|
||||||
|
return std::static_pointer_cast<gloo::mpi::Context>(group_)->size;
|
||||||
|
}
|
||||||
|
|
||||||
|
Group Group::split(int color, int key) {
|
||||||
|
throw std::runtime_error("split is NYI");
|
||||||
|
}
|
||||||
|
|
||||||
|
void Group::barrier() {
|
||||||
|
throw std::runtime_error("barrier is NYI");
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GlooCTX {
|
||||||
|
std::shared_ptr<gloo::mpi::Context> context;
|
||||||
|
std::shared_ptr<gloo::transport::Device> dev;
|
||||||
|
};
|
||||||
|
|
||||||
|
Group init(bool strict /* = false */) {
|
||||||
|
static std::shared_ptr<GlooCTX> gloo_ctx = nullptr;
|
||||||
|
|
||||||
|
if (gloo_ctx == nullptr) {
|
||||||
|
gloo_ctx = std::make_shared<GlooCTX>();
|
||||||
|
gloo_ctx->context = gloo::mpi::Context::createManaged();
|
||||||
|
gloo_ctx->dev = gloo::transport::uv::CreateDevice("localhost");
|
||||||
|
gloo_ctx->context->connectFullMesh(gloo_ctx->dev);
|
||||||
|
}
|
||||||
|
return Group(gloo_ctx->context);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
Stream communication_stream() {
|
||||||
|
static Stream comm_stream = new_stream(Device::cpu);
|
||||||
|
return comm_stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void all_reduce_sum(
|
||||||
|
std::shared_ptr<gloo::mpi::Context> context,
|
||||||
|
T* output,
|
||||||
|
T* input,
|
||||||
|
size_t len) {
|
||||||
|
gloo::AllreduceOptions opts_(context);
|
||||||
|
opts_.setInput(input, len);
|
||||||
|
opts_.setOutput(output, len);
|
||||||
|
opts_.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
|
||||||
|
opts_.setReduceFunction(
|
||||||
|
static_cast<void (*)(void*, const void*, const void*, size_t)>(
|
||||||
|
&gloo::sum<T>));
|
||||||
|
gloo::allreduce(opts_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_sum(Group group_, const array& input_, array& output) {
|
||||||
|
array input = ensure_row_contiguous(input_);
|
||||||
|
if (input.data<void>() != output.data<void>()) {
|
||||||
|
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
|
||||||
|
}
|
||||||
|
auto context =
|
||||||
|
std::static_pointer_cast<gloo::mpi::Context>(group_.raw_group());
|
||||||
|
SWITCH_TYPE(
|
||||||
|
output,
|
||||||
|
all_reduce_sum<T>(
|
||||||
|
context, output.data<T>(), input.data<T>(), input.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
void all_gather(Group group_, const array& input_, array& output) {
|
||||||
|
throw std::runtime_error("all_gather NYI");
|
||||||
|
}
|
||||||
|
|
||||||
|
void send(Group group_, const array& input_, int dst) {
|
||||||
|
throw std::runtime_error("send NYI");
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv(Group group_, array& out, int src) {
|
||||||
|
throw std::runtime_error("recv NYI");
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
|
} // namespace mlx::core::distributed
|
Loading…
Reference in New Issue
Block a user