diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index 4e57aa1f3..3075e56fa 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -1,16 +1,20 @@ target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp) -if (MLX_BUILD_CPU) - if (MLX_CUSTOM_DISTRIBUTED) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets) - elseif (MPI_FOUND) - add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) +if(MLX_BUILD_CPU) + if(MLX_CUSTOM_DISTRIBUTED) + if(MLX_CUSTOM_DISTRIBUTED STREQUAL "gloo") + message(STATUS "Distributed: using gloo backend") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/gloo) else() - target_sources( - mlx - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp - ) + message(STATUS "Distributed: using sockets backend") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets) endif() + elseif(MPI_FOUND) + message(STATUS "Distributed: using MPI backend") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi) + else() + message(STATUS "Distributed: no support") + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp) + endif() endif() diff --git a/mlx/distributed/gloo/CMakeLists.txt b/mlx/distributed/gloo/CMakeLists.txt new file mode 100644 index 000000000..894efb02c --- /dev/null +++ b/mlx/distributed/gloo/CMakeLists.txt @@ -0,0 +1,25 @@ +find_path( + GLOO_INCLUDE_DIR gloo/allreduce.h + PATHS ${GLOO_INC_DIR} + PATH_SUFFIXES include) + +find_library( + GLOO_LIBRARY gloo + PATHS ${GLOO_LIB_DIR} + PATH_SUFFIXES lib + HINTS GLOO) + +find_library( + UV_LIBRARY uv + PATHS ${UV_LIB_DIR} + PATH_SUFFIXES lib + HINTS UV) + +message(STATUS "GLOO LIB <${GLOO_LIBRARY}>") +message(STATUS "GLOO INC <${GLOO_INCLUDE_DIR}>") +message(STATUS "UV LIB <${UV_LIB_DIR}>") + +target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gloo.cpp) +target_link_libraries(mlx PUBLIC ${GLOO_LIBRARY}) +target_link_libraries(mlx PUBLIC ${UV_LIBRARY}) +target_include_directories(mlx PRIVATE ${GLOO_INCLUDE_DIR}) diff --git a/mlx/distributed/gloo/gloo.cpp b/mlx/distributed/gloo/gloo.cpp new file mode 100644 index 000000000..6b478ac3b --- /dev/null +++ b/mlx/distributed/gloo/gloo.cpp @@ -0,0 +1,178 @@ +// Copyright © 2024 Apple Inc. + +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/distributed_impl.h" +#include "mlx/io/threadpool.h" + +#include "gloo/allreduce.h" +#include "gloo/math.h" +#include "gloo/mpi/context.h" +#include "gloo/transport/uv/device.h" + +#define SWITCH_TYPE(x, ...) \ + switch ((x).dtype()) { \ + case bool_: { \ + using T = bool; \ + __VA_ARGS__; \ + } break; \ + case int8: { \ + using T = int8_t; \ + __VA_ARGS__; \ + } break; \ + case int16: { \ + using T = int16_t; \ + __VA_ARGS__; \ + } break; \ + case int32: { \ + using T = int32_t; \ + __VA_ARGS__; \ + } break; \ + case int64: { \ + using T = int64_t; \ + __VA_ARGS__; \ + } break; \ + case uint8: { \ + using T = uint8_t; \ + __VA_ARGS__; \ + } break; \ + case uint16: { \ + using T = uint16_t; \ + __VA_ARGS__; \ + } break; \ + case uint32: { \ + using T = uint32_t; \ + __VA_ARGS__; \ + } break; \ + case uint64: { \ + using T = uint64_t; \ + __VA_ARGS__; \ + } break; \ + case bfloat16: { \ + using T = bfloat16_t; \ + __VA_ARGS__; \ + } break; \ + case float16: { \ + using T = float16_t; \ + __VA_ARGS__; \ + } break; \ + case float32: { \ + using T = float; \ + __VA_ARGS__; \ + } break; \ + case complex64: { \ + using T = complex64_t; \ + __VA_ARGS__; \ + } break; \ + } + +namespace mlx::core::distributed { + +namespace { +array ensure_row_contiguous(const array& arr) { + if (arr.flags().row_contiguous) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + return arr_copy; + } +} +} // namespace + +bool is_available() { + return true; +} + +int Group::rank() { + return std::static_pointer_cast(group_)->rank; +} + +int Group::size() { + return std::static_pointer_cast(group_)->size; +} + +Group Group::split(int color, int key) { + throw std::runtime_error("split is NYI"); +} + +void Group::barrier() { + throw std::runtime_error("barrier is NYI"); +} + +struct GlooCTX { + std::shared_ptr context; + std::shared_ptr dev; +}; + +Group init(bool strict /* = false */) { + static std::shared_ptr gloo_ctx = nullptr; + + if (gloo_ctx == nullptr) { + gloo_ctx = std::make_shared(); + gloo_ctx->context = gloo::mpi::Context::createManaged(); + gloo_ctx->dev = gloo::transport::uv::CreateDevice("localhost"); + gloo_ctx->context->connectFullMesh(gloo_ctx->dev); + } + return Group(gloo_ctx->context); +} + +namespace detail { + +Stream communication_stream() { + static Stream comm_stream = new_stream(Device::cpu); + return comm_stream; +} + +template +void all_reduce_sum( + std::shared_ptr context, + T* output, + T* input, + size_t len) { + gloo::AllreduceOptions opts_(context); + opts_.setInput(input, len); + opts_.setOutput(output, len); + opts_.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING); + opts_.setReduceFunction( + static_cast( + &gloo::sum)); + gloo::allreduce(opts_); +} + +void all_sum(Group group_, const array& input_, array& output) { + array input = ensure_row_contiguous(input_); + if (input.data() != output.data()) { + std::memcpy(output.data(), input.data(), input.nbytes()); + } + auto context = + std::static_pointer_cast(group_.raw_group()); + SWITCH_TYPE( + output, + all_reduce_sum( + context, output.data(), input.data(), input.size())); +} + +void all_gather(Group group_, const array& input_, array& output) { + throw std::runtime_error("all_gather NYI"); +} + +void send(Group group_, const array& input_, int dst) { + throw std::runtime_error("send NYI"); +} + +void recv(Group group_, array& out, int src) { + throw std::runtime_error("recv NYI"); +} + +} // namespace detail + +} // namespace mlx::core::distributed