mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Gloo backend support
This commit is contained in:
		@@ -1,16 +1,20 @@
 | 
			
		||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
 | 
			
		||||
                           ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp)
 | 
			
		||||
 | 
			
		||||
if (MLX_BUILD_CPU)
 | 
			
		||||
    if (MLX_CUSTOM_DISTRIBUTED)
 | 
			
		||||
      add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
 | 
			
		||||
    elseif (MPI_FOUND)
 | 
			
		||||
      add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 | 
			
		||||
if(MLX_BUILD_CPU)
 | 
			
		||||
  if(MLX_CUSTOM_DISTRIBUTED)
 | 
			
		||||
    if(MLX_CUSTOM_DISTRIBUTED STREQUAL "gloo")
 | 
			
		||||
      message(STATUS "Distributed: using gloo backend")
 | 
			
		||||
      add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/gloo)
 | 
			
		||||
    else()
 | 
			
		||||
      target_sources(
 | 
			
		||||
        mlx
 | 
			
		||||
        PRIVATE
 | 
			
		||||
        ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
 | 
			
		||||
      )
 | 
			
		||||
      message(STATUS "Distributed: using sockets backend")
 | 
			
		||||
      add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/sockets)
 | 
			
		||||
    endif()
 | 
			
		||||
  elseif(MPI_FOUND)
 | 
			
		||||
    message(STATUS "Distributed: using MPI backend")
 | 
			
		||||
    add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
 | 
			
		||||
  else()
 | 
			
		||||
    message(STATUS "Distributed: no support")
 | 
			
		||||
    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp)
 | 
			
		||||
  endif()
 | 
			
		||||
endif()
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										25
									
								
								mlx/distributed/gloo/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								mlx/distributed/gloo/CMakeLists.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
			
		||||
find_path(
 | 
			
		||||
  GLOO_INCLUDE_DIR gloo/allreduce.h
 | 
			
		||||
  PATHS ${GLOO_INC_DIR}
 | 
			
		||||
  PATH_SUFFIXES include)
 | 
			
		||||
 | 
			
		||||
find_library(
 | 
			
		||||
  GLOO_LIBRARY gloo
 | 
			
		||||
  PATHS ${GLOO_LIB_DIR}
 | 
			
		||||
  PATH_SUFFIXES lib
 | 
			
		||||
  HINTS GLOO)
 | 
			
		||||
 | 
			
		||||
find_library(
 | 
			
		||||
  UV_LIBRARY uv
 | 
			
		||||
  PATHS ${UV_LIB_DIR}
 | 
			
		||||
  PATH_SUFFIXES lib
 | 
			
		||||
  HINTS UV)
 | 
			
		||||
 | 
			
		||||
message(STATUS "GLOO LIB <${GLOO_LIBRARY}>")
 | 
			
		||||
message(STATUS "GLOO INC <${GLOO_INCLUDE_DIR}>")
 | 
			
		||||
message(STATUS "UV   LIB <${UV_LIB_DIR}>")
 | 
			
		||||
 | 
			
		||||
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gloo.cpp)
 | 
			
		||||
target_link_libraries(mlx PUBLIC ${GLOO_LIBRARY})
 | 
			
		||||
target_link_libraries(mlx PUBLIC ${UV_LIBRARY})
 | 
			
		||||
target_include_directories(mlx PRIVATE ${GLOO_INCLUDE_DIR})
 | 
			
		||||
							
								
								
									
										178
									
								
								mlx/distributed/gloo/gloo.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										178
									
								
								mlx/distributed/gloo/gloo.cpp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,178 @@
 | 
			
		||||
// Copyright © 2024 Apple Inc.
 | 
			
		||||
 | 
			
		||||
#include <unistd.h>
 | 
			
		||||
#include <chrono>
 | 
			
		||||
#include <cstdlib>
 | 
			
		||||
#include <fstream>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
#include <thread>
 | 
			
		||||
 | 
			
		||||
#include "mlx/backend/common/copy.h"
 | 
			
		||||
#include "mlx/distributed/distributed.h"
 | 
			
		||||
#include "mlx/distributed/distributed_impl.h"
 | 
			
		||||
#include "mlx/io/threadpool.h"
 | 
			
		||||
 | 
			
		||||
#include "gloo/allreduce.h"
 | 
			
		||||
#include "gloo/math.h"
 | 
			
		||||
#include "gloo/mpi/context.h"
 | 
			
		||||
#include "gloo/transport/uv/device.h"
 | 
			
		||||
 | 
			
		||||
#define SWITCH_TYPE(x, ...)  \
 | 
			
		||||
  switch ((x).dtype()) {     \
 | 
			
		||||
    case bool_: {            \
 | 
			
		||||
      using T = bool;        \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case int8: {             \
 | 
			
		||||
      using T = int8_t;      \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case int16: {            \
 | 
			
		||||
      using T = int16_t;     \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case int32: {            \
 | 
			
		||||
      using T = int32_t;     \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case int64: {            \
 | 
			
		||||
      using T = int64_t;     \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case uint8: {            \
 | 
			
		||||
      using T = uint8_t;     \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case uint16: {           \
 | 
			
		||||
      using T = uint16_t;    \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case uint32: {           \
 | 
			
		||||
      using T = uint32_t;    \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case uint64: {           \
 | 
			
		||||
      using T = uint64_t;    \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case bfloat16: {         \
 | 
			
		||||
      using T = bfloat16_t;  \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case float16: {          \
 | 
			
		||||
      using T = float16_t;   \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case float32: {          \
 | 
			
		||||
      using T = float;       \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
    case complex64: {        \
 | 
			
		||||
      using T = complex64_t; \
 | 
			
		||||
      __VA_ARGS__;           \
 | 
			
		||||
    } break;                 \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
namespace mlx::core::distributed {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
array ensure_row_contiguous(const array& arr) {
 | 
			
		||||
  if (arr.flags().row_contiguous) {
 | 
			
		||||
    return arr;
 | 
			
		||||
  } else {
 | 
			
		||||
    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
 | 
			
		||||
    copy(arr, arr_copy, CopyType::General);
 | 
			
		||||
    return arr_copy;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
} // namespace
 | 
			
		||||
 | 
			
		||||
bool is_available() {
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int Group::rank() {
 | 
			
		||||
  return std::static_pointer_cast<gloo::mpi::Context>(group_)->rank;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int Group::size() {
 | 
			
		||||
  return std::static_pointer_cast<gloo::mpi::Context>(group_)->size;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Group Group::split(int color, int key) {
 | 
			
		||||
  throw std::runtime_error("split is NYI");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Group::barrier() {
 | 
			
		||||
  throw std::runtime_error("barrier is NYI");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct GlooCTX {
 | 
			
		||||
  std::shared_ptr<gloo::mpi::Context> context;
 | 
			
		||||
  std::shared_ptr<gloo::transport::Device> dev;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
Group init(bool strict /* = false */) {
 | 
			
		||||
  static std::shared_ptr<GlooCTX> gloo_ctx = nullptr;
 | 
			
		||||
 | 
			
		||||
  if (gloo_ctx == nullptr) {
 | 
			
		||||
    gloo_ctx = std::make_shared<GlooCTX>();
 | 
			
		||||
    gloo_ctx->context = gloo::mpi::Context::createManaged();
 | 
			
		||||
    gloo_ctx->dev = gloo::transport::uv::CreateDevice("localhost");
 | 
			
		||||
    gloo_ctx->context->connectFullMesh(gloo_ctx->dev);
 | 
			
		||||
  }
 | 
			
		||||
  return Group(gloo_ctx->context);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace detail {
 | 
			
		||||
 | 
			
		||||
Stream communication_stream() {
 | 
			
		||||
  static Stream comm_stream = new_stream(Device::cpu);
 | 
			
		||||
  return comm_stream;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
void all_reduce_sum(
 | 
			
		||||
    std::shared_ptr<gloo::mpi::Context> context,
 | 
			
		||||
    T* output,
 | 
			
		||||
    T* input,
 | 
			
		||||
    size_t len) {
 | 
			
		||||
  gloo::AllreduceOptions opts_(context);
 | 
			
		||||
  opts_.setInput(input, len);
 | 
			
		||||
  opts_.setOutput(output, len);
 | 
			
		||||
  opts_.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
 | 
			
		||||
  opts_.setReduceFunction(
 | 
			
		||||
      static_cast<void (*)(void*, const void*, const void*, size_t)>(
 | 
			
		||||
          &gloo::sum<T>));
 | 
			
		||||
  gloo::allreduce(opts_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void all_sum(Group group_, const array& input_, array& output) {
 | 
			
		||||
  array input = ensure_row_contiguous(input_);
 | 
			
		||||
  if (input.data<void>() != output.data<void>()) {
 | 
			
		||||
    std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
 | 
			
		||||
  }
 | 
			
		||||
  auto context =
 | 
			
		||||
      std::static_pointer_cast<gloo::mpi::Context>(group_.raw_group());
 | 
			
		||||
  SWITCH_TYPE(
 | 
			
		||||
      output,
 | 
			
		||||
      all_reduce_sum<T>(
 | 
			
		||||
          context, output.data<T>(), input.data<T>(), input.size()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void all_gather(Group group_, const array& input_, array& output) {
 | 
			
		||||
  throw std::runtime_error("all_gather NYI");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void send(Group group_, const array& input_, int dst) {
 | 
			
		||||
  throw std::runtime_error("send NYI");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void recv(Group group_, array& out, int src) {
 | 
			
		||||
  throw std::runtime_error("recv NYI");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
} // namespace mlx::core::distributed
 | 
			
		||||
		Reference in New Issue
	
	Block a user