diff --git a/.circleci/config.yml b/.circleci/config.yml
index 9965c98e4..0f528a33d 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -71,6 +71,7 @@ jobs:
           name: Install dependencies
           command: |
             brew install python@3.8
+            brew install openmpi
             python3.8 -m venv env
             source env/bin/activate
             pip install --upgrade pip
@@ -96,6 +97,7 @@
             source env/bin/activate
             LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
             LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
+            mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
       - run:
          name: Build example extension
          command: |
diff --git a/CMakeLists.txt b/CMakeLists.txt
index be6ebe1c3..dc7412b38 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -167,6 +167,11 @@ else()
   set(MLX_BUILD_ACCELERATE OFF)
 endif()
 
+find_package(MPI)
+if (MPI_FOUND)
+  target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
+endif()
+
 add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
 
 target_include_directories(
diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt
index cabb723fa..1b3969c97 100644
--- a/examples/cpp/CMakeLists.txt
+++ b/examples/cpp/CMakeLists.txt
@@ -9,3 +9,4 @@ build_example(tutorial.cpp)
 build_example(linear_regression.cpp)
 build_example(logistic_regression.cpp)
 build_example(metal_capture.cpp)
+build_example(distributed.cpp)
diff --git a/examples/cpp/distributed.cpp b/examples/cpp/distributed.cpp
new file mode 100644
index 000000000..283751a62
--- /dev/null
+++ b/examples/cpp/distributed.cpp
@@ -0,0 +1,22 @@
+// Copyright © 2024 Apple Inc.
+
+#include <iostream>
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+int main() {
+  if (!distributed::is_available()) {
+    std::cout << "No communication backend found" << std::endl;
+    return 1;
+  }
+
+  auto global_group = distributed::init();
+  std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
+
+  array x = ones({10});
+  array out = distributed::all_reduce_sum(x, global_group);
+
+  std::cout << out << std::endl;
+}
diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt
index c53c3ec7d..14c24896d 100644
--- a/mlx/CMakeLists.txt
+++ b/mlx/CMakeLists.txt
@@ -25,6 +25,7 @@ else()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
 endif()
 
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
 if (MLX_BUILD_ACCELERATE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)
diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt
new file mode 100644
index 000000000..d7521a365
--- /dev/null
+++ b/mlx/distributed/CMakeLists.txt
@@ -0,0 +1,16 @@
+target_sources(
+  mlx
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
+)
+
+if (MPI_FOUND AND MLX_BUILD_CPU)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
+else()
+  target_sources(
+    mlx
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
+  )
+endif()
diff --git a/mlx/distributed/distributed.h b/mlx/distributed/distributed.h
new file mode 100644
index 000000000..cad75b396
--- /dev/null
+++ b/mlx/distributed/distributed.h
@@ -0,0 +1,62 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <memory>
+
+#include "mlx/array.h"
+
+namespace mlx::core::distributed {
+
+/* Check if a communication backend is available */
+bool is_available();
+
+/**
+ * A distributed::Group represents a group of independent mlx processes that
+ * can communicate. We must also be able to create sub-groups from a group in
+ * order to define more granular communication.
+ */
+struct Group {
+  Group(std::shared_ptr<void> group) : group_(group) {}
+
+  int rank();
+  int size();
+
+  /**
+   * Split the group according to the provided color. Namely processes that
+   * use the same color will go to the same group.
+   *
+   * The key defines the rank of the processes in the new group. The smaller
+   * the key the smaller the rank. If the provided key is negative, then the
+   * rank in the current group is used.
+   */
+  Group split(int color, int key = -1);
+
+  const std::shared_ptr<void>& raw_group() {
+    return group_;
+  }
+
+ private:
+  std::shared_ptr<void> group_{nullptr};
+};
+
+/**
+ * Initialize the distributed backend and return the group containing all
+ * discoverable processes.
+ */
+Group init();
+
+namespace detail {
+
+/* Return the communication stream. */
+Stream communication_stream();
+
+/* Perform an all reduce sum operation */
+void all_reduce_sum(Group group, const array& input, array& output);
+
+/* Perform an all gather operation */
+void all_gather(Group group, const array& input, array& output);
+
+} // namespace detail
+
+} // namespace mlx::core::distributed
diff --git a/mlx/distributed/mpi/CMakeLists.txt b/mlx/distributed/mpi/CMakeLists.txt
new file mode 100644
index 000000000..3caca724c
--- /dev/null
+++ b/mlx/distributed/mpi/CMakeLists.txt
@@ -0,0 +1,5 @@
+target_sources(
+  mlx
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
+)
diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp
new file mode 100644
index 000000000..3d1818195
--- /dev/null
+++ b/mlx/distributed/mpi/mpi.cpp
@@ -0,0 +1,283 @@
+// Copyright © 2024 Apple Inc.
+
+#include <dlfcn.h>
+#include <mpi.h>
+
+#include "mlx/backend/common/copy.h"
+#include "mlx/distributed/distributed.h"
+#include "mlx/scheduler.h"
+
+#define LOAD_SYMBOL(symbol, variable)                                \
+  {                                                                  \
+    variable = (decltype(variable))dlsym(libmpi_handle_, #symbol);   \
+    char* error = dlerror();                                         \
+    if (error != nullptr) {                                          \
+      libmpi_handle_ = nullptr;                                      \
+      return;                                                        \
+    }                                                                \
+  }
+
+namespace mlx::core::distributed {
+
+namespace {
+
+array ensure_row_contiguous(const array& arr) {
+  if (arr.flags().row_contiguous) {
+    return arr;
+  } else {
+    array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+    copy(arr, arr_copy, CopyType::General);
+    return arr_copy;
+  }
+}
+
+struct MPIWrapper {
+  MPIWrapper() {
+    libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
+    if (libmpi_handle_ == nullptr) {
+      return;
+    }
+
+    // API
+    LOAD_SYMBOL(MPI_Init, init);
+    LOAD_SYMBOL(MPI_Finalize, finalize);
+    LOAD_SYMBOL(MPI_Comm_rank, rank);
+    LOAD_SYMBOL(MPI_Comm_size, size);
+    LOAD_SYMBOL(MPI_Comm_split, comm_split);
+    LOAD_SYMBOL(MPI_Comm_free, comm_free);
+    LOAD_SYMBOL(MPI_Allreduce, all_reduce);
+    LOAD_SYMBOL(MPI_Allgather, all_gather);
+
+    // Objects
+    LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
+
+    // Ops
+    LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
+
+    // Datatypes
+    LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
+    LOAD_SYMBOL(ompi_mpi_int8_t, mpi_int8_);
+    LOAD_SYMBOL(ompi_mpi_uint8_t, mpi_uint8_);
+    LOAD_SYMBOL(ompi_mpi_int16_t, mpi_int16_);
+    LOAD_SYMBOL(ompi_mpi_uint16_t, mpi_uint16_);
+    LOAD_SYMBOL(ompi_mpi_int32_t, mpi_int32_);
+    LOAD_SYMBOL(ompi_mpi_uint32_t, mpi_uint32_);
+    LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
+    LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
+    LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
+    LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
+  }
+
+  bool is_available() {
+    return libmpi_handle_ != nullptr;
+  }
+
+  bool init_safe() {
+    if (!is_available()) {
+      return false;
+    }
+    return init(nullptr, nullptr) == MPI_SUCCESS;
+  }
+
+  void finalize_safe() {
+    if (is_available()) {
+      finalize();
+    }
+  }
+
+  MPI_Comm world() {
+    return comm_world_;
+  }
+
+  MPI_Datatype datatype(const array& arr) {
+    switch (arr.dtype()) {
+      case bool_:
+        return mpi_bool_;
+      case int8:
+        return mpi_int8_;
+      case uint8:
+        return mpi_uint8_;
+      case int16:
+        return mpi_int16_;
+      case uint16:
+        return mpi_uint16_;
+      case int32:
+        return mpi_int32_;
+      case uint32:
+        return mpi_uint32_;
+      case int64:
+        return mpi_int64_;
+      case uint64:
+        return mpi_uint64_;
+      case float32:
+        return mpi_float_;
+      case complex64:
+        return mpi_complex_;
+      case float16:
+      case bfloat16:
+        throw std::runtime_error("MPI doesn't support 16-bit floats");
+    }
+  }
+
+  MPI_Op op_sum() {
+    return op_sum_;
+  }
+
+  void* libmpi_handle_;
+
+  // API
+  int (*init)(int*, char***);
+  int (*finalize)();
+  int (*rank)(MPI_Comm, int*);
+  int (*size)(MPI_Comm, int*);
+  int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
+  int (*all_gather)(
+      const void*,
+      int,
+      MPI_Datatype,
+      void*,
+      int,
+      MPI_Datatype,
+      MPI_Comm);
+  int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
+  int (*comm_free)(MPI_Comm*);
+
+  // Objects
+  MPI_Comm comm_world_;
+
+  // Ops
+  MPI_Op op_sum_;
+
+  // Datatypes
+  MPI_Datatype mpi_bool_;
+  MPI_Datatype mpi_int8_;
+  MPI_Datatype mpi_uint8_;
+  MPI_Datatype mpi_int16_;
+  MPI_Datatype mpi_uint16_;
+  MPI_Datatype mpi_int32_;
+  MPI_Datatype mpi_uint32_;
+  MPI_Datatype mpi_int64_;
+  MPI_Datatype mpi_uint64_;
+  MPI_Datatype mpi_float_;
+  MPI_Datatype mpi_complex_;
+};
+
+MPIWrapper& mpi() {
+  static MPIWrapper wrapper;
+  return wrapper;
+}
+
+struct MPIGroupImpl {
+  MPIGroupImpl(MPI_Comm comm, bool global)
+      : comm_(comm), global_(global), rank_(-1), size_(-1) {}
+  ~MPIGroupImpl() {
+    if (global_) {
+      mpi().finalize_safe();
+    } else {
+      mpi().comm_free(&comm_);
+    }
+  }
+
+  MPI_Comm comm() {
+    return comm_;
+  }
+
+  int rank() {
+    if (rank_ < 0) {
+      mpi().rank(comm_, &rank_);
+    }
+    return rank_;
+  }
+
+  int size() {
+    if (size_ < 0) {
+      mpi().size(comm_, &size_);
+    }
+    return size_;
+  }
+
+ private:
+  MPI_Comm comm_;
+  bool global_;
+  int rank_;
+  int size_;
+};
+
+MPI_Comm to_comm(Group& group) {
+  return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
+}
+
+} // namespace
+
+int Group::rank() {
+  return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
+}
+
+int Group::size() {
+  return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
+}
+
+Group Group::split(int color, int key) {
+  auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
+
+  key = (key < 0) ? rank() : key;
+
+  MPI_Comm new_comm;
+  int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
+  if (result != MPI_SUCCESS) {
+    throw std::runtime_error("MPI could not split this group");
+  }
+
+  return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
+}
+
+bool is_available() {
+  return mpi().is_available();
+}
+
+Group init() {
+  static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
+
+  if (global_group == nullptr) {
+    if (!mpi().init_safe()) {
+      throw std::runtime_error("Cannot initialize MPI");
+    }
+    global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
+  }
+
+  return Group(global_group);
+}
+
+namespace detail {
+
+Stream communication_stream() {
+  static Stream comm_stream = new_stream(Device::cpu);
+  return comm_stream;
+}
+
+void all_reduce_sum(Group group, const array& input_, array& output) {
+  array input = ensure_row_contiguous(input_);
+  mpi().all_reduce(
+      input.data<void>(),
+      output.data<void>(),
+      input.size(),
+      mpi().datatype(input),
+      mpi().op_sum(),
+      to_comm(group));
+}
+
+void all_gather(Group group, const array& input_, array& output) {
+  array input = ensure_row_contiguous(input_);
+  mpi().all_gather(
+      input.data<void>(),
+      input.size(),
+      mpi().datatype(input),
+      output.data<void>(),
+      input.size(),
+      mpi().datatype(output),
+      to_comm(group));
+}
+
+} // namespace detail
+
+} // namespace mlx::core::distributed
diff --git a/mlx/distributed/no_distributed.cpp b/mlx/distributed/no_distributed.cpp
new file mode 100644
index 000000000..d85428496
--- /dev/null
+++ b/mlx/distributed/no_distributed.cpp
@@ -0,0 +1,39 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/distributed.h"
+
+namespace mlx::core::distributed {
+
+int Group::rank() {
+  return 0;
+}
+
+int Group::size() {
+  return 1;
+}
+
+Group Group::split(int color, int key) {
+  throw std::runtime_error("Cannot split the distributed group further");
+}
+
+bool is_available() {
+  return false;
+}
+
+Group init() {
+  return Group(nullptr);
+}
+
+namespace detail {
+
+Stream communication_stream() {
+  static Stream comm_stream = new_stream(Device::cpu);
+  return comm_stream;
+}
+
+void all_reduce_sum(Group group, const array& input, array& output) {}
+void all_gather(Group group, const array& input, array& output) {}
+
+} // namespace detail
+
+} // namespace mlx::core::distributed
diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp
new file mode 100644
index 000000000..69cf196cb
--- /dev/null
+++ b/mlx/distributed/ops.cpp
@@ -0,0 +1,54 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/distributed/ops.h"
+#include "mlx/distributed/primitives.h"
+
+namespace mlx::core::distributed {
+
+namespace {
+
+Group to_group(std::optional<Group> group) {
+  if (group.has_value()) {
+    return group.value();
+  } else {
+    return distributed::init();
+  }
+}
+
+} // namespace
+
+array all_reduce_sum(const array& x, std::optional<Group> group_) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    return x;
+  }
+
+  return array(
+      x.shape(),
+      x.dtype(),
+      std::make_shared<AllReduce>(group, AllReduce::Sum),
+      {x});
+}
+
+array all_gather(const array& x, std::optional<Group> group_) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    return x;
+  }
+
+  auto result_shape = x.shape();
+  if (result_shape.size() == 0) {
+    result_shape.push_back(group.size());
+  } else {
+    result_shape[0] *= group.size();
+  }
+  return array(
+      std::move(result_shape),
+      x.dtype(),
+      std::make_shared<AllGather>(group),
+      {x});
+}
+
+} // namespace mlx::core::distributed
diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h
new file mode 100644
index 000000000..1afe8dcc8
--- /dev/null
+++ b/mlx/distributed/ops.h
@@ -0,0 +1,14 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include <optional>
+
+#include "mlx/distributed/distributed.h"
+
+namespace mlx::core::distributed {
+
+array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
+array all_gather(const array& x, std::optional<Group> group = std::nullopt);
+
+} // namespace mlx::core::distributed
diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp
new file mode 100644
index 000000000..b20fde605
--- /dev/null
+++ b/mlx/distributed/primitives.cpp
@@ -0,0 +1,98 @@
+// Copyright © 2024 Apple Inc.
+
+#include <cassert>
+
+#include "mlx/allocator.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/distributed/ops.h"
+#include "mlx/distributed/primitives.h"
+#include "mlx/ops.h"
+
+namespace mlx::core::distributed {
+
+void AllReduce::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+
+  switch (reduce_type_) {
+    case Sum:
+      distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
+      break;
+    default:
+      throw std::runtime_error("Only all reduce sum is supported for now");
+  }
+}
+
+std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  switch (reduce_type_) {
+    case Sum:
+      return {{all_reduce_sum(inputs[0], group())}, axes};
+    default:
+      throw std::runtime_error("Only all reduce sum is supported for now");
+  }
+}
+
+std::vector<array> AllReduce::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  switch (reduce_type_) {
+    case Sum:
+      return {all_reduce_sum(tangents[0], group())};
+    default:
+      throw std::runtime_error("Only all reduce sum is supported for now");
+  }
+}
+
+std::vector<array> AllReduce::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>& outputs) {
+  return cotangents;
+}
+
+void AllGather::eval_cpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
+
+  distributed::detail::all_gather(group(), inputs[0], outputs[0]);
+}
+
+std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
+    const std::vector<array>& inputs,
+    const std::vector<int>& axes) {
+  return {{all_gather(inputs[0], group())}, axes};
+}
+
+std::vector<array> AllGather::jvp(
+    const std::vector<array>& primals,
+    const std::vector<array>& tangents,
+    const std::vector<int>& argnums) {
+  return {all_gather(tangents[0], group())};
+}
+
+std::vector<array> AllGather::vjp(
+    const std::vector<array>& primals,
+    const std::vector<array>& cotangents,
+    const std::vector<int>& argnums,
+    const std::vector<array>& outputs) {
+  auto g = group();
+  std::vector<int> starts(primals[0].ndim(), 0);
+  auto stops = primals[0].shape();
+  starts[0] = g.rank() * stops[0];
+  stops[0] += starts[0];
+  return {slice(cotangents[0], starts, stops)};
+}
+
+} // namespace mlx::core::distributed
diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h
new file mode 100644
index 000000000..8107f4b12
--- /dev/null
+++ b/mlx/distributed/primitives.h
@@ -0,0 +1,100 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/distributed/distributed.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core::distributed {
+
+class DistPrimitive : public Primitive {
+ public:
+  DistPrimitive(Group group)
+      : Primitive(detail::communication_stream()), group_(group) {}
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error(
+        "Communication primitives cannot be run on the GPU");
+  }
+
+  const Group& group() const {
+    return group_;
+  }
+
+ private:
+  Group group_;
+};
+
+class AllReduce : public DistPrimitive {
+ public:
+  enum ReduceType { And, Or, Sum, Prod, Min, Max };
+
+  AllReduce(Group group, ReduceType reduce_type)
+      : DistPrimitive(group), reduce_type_(reduce_type) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+  std::vector<array> jvp(
+      const std::vector<array>& primals,
+      const std::vector<array>& tangents,
+      const std::vector<int>& argnums) override;
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+  void print(std::ostream& os) override {
+    switch (reduce_type_) {
+      case And:
+        os << "And";
+        break;
+      case Or:
+        os << "Or";
+        break;
+      case Sum:
+        os << "Sum";
+        break;
+      case Prod:
+        os << "Prod";
+        break;
+      case Min:
+        os << "Min";
+        break;
+      case Max:
+        os << "Max";
+        break;
+    }
+    os << " AllReduce";
+  }
+
+ private:
+  ReduceType reduce_type_;
+};
+
+class AllGather : public DistPrimitive {
+ public:
+  AllGather(Group group) : DistPrimitive(group) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+  std::vector<array> jvp(
+      const std::vector<array>& primals,
+      const std::vector<array>& tangents,
+      const std::vector<int>& argnums) override;
+  std::vector<array> vjp(
+      const std::vector<array>& primals,
+      const std::vector<array>& cotangents,
+      const std::vector<int>& argnums,
+      const std::vector<array>& outputs) override;
+
+  DEFINE_PRINT(AllGather);
+};
+
+} // namespace mlx::core::distributed
diff --git a/mlx/mlx.h b/mlx/mlx.h
index 1963a4c50..d8fe150ed 100644
--- a/mlx/mlx.h
+++ b/mlx/mlx.h
@@ -6,6 +6,8 @@
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"
+#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/ops.h"
 #include "mlx/fast.h"
 #include "mlx/fft.h"
 #include "mlx/io.h"
diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt
index ae0531385..c74ce9c95 100644
--- a/python/src/CMakeLists.txt
+++ b/python/src/CMakeLists.txt
@@ -6,6 +6,7 @@ nanobind_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp
new file mode 100644
index 000000000..069b5a885
--- /dev/null
+++ b/python/src/distributed.cpp
@@ -0,0 +1,107 @@
+// Copyright © 2024 Apple Inc.
+
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/shared_ptr.h>
+
+#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/ops.h"
+
+namespace nb = nanobind;
+using namespace nb::literals;
+using namespace mlx::core;
+
+void init_distributed(nb::module_& parent_module) {
+  auto m = parent_module.def_submodule(
+      "distributed", "mlx.core.distributed: Communication operations");
+
+  nb::class_<distributed::Group>(
+      m,
+      "Group",
+      R"pbcopy(
+        An :class:`mlx.core.distributed.Group` represents a group of independent mlx
+        processes that can communicate.
+      )pbcopy")
+      .def("rank", &distributed::Group::rank, "Get the rank of this process")
+      .def("size", &distributed::Group::size, "Get the size of the group")
+      .def(
+          "split",
+          &distributed::Group::split,
+          "color"_a,
+          "key"_a = -1,
+          nb::sig("def split(self, color: int, key: int = -1) -> Group"),
+          R"pbdoc(
+            Split the group to subgroups based on the provided color.
+
+            Processes that use the same color go to the same group. The ``key``
+            argument defines the rank in the new group. The smaller the key the
+            smaller the rank. If the key is negative then the rank in the
+            current group is used.
+
+            Args:
+              color (int): A value to group processes into subgroups.
+              key (int, optional): A key to optionally change the rank ordering
+                of the processes.
+          )pbdoc");
+
+  m.def(
+      "is_available",
+      &distributed::is_available,
+      R"pbdoc(
+      Check if a communication backend is available.
+      )pbdoc");
+
+  m.def(
+      "init",
+      &distributed::init,
+      R"pbdoc(
+      Initialize the communication backend and create the global communication group.
+      )pbdoc");
+
+  m.def(
+      "all_reduce_sum",
+      &distributed::all_reduce_sum,
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      nb::sig(
+          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+      R"pbdoc(
+        All reduce sum.
+
+        Sum the ``x`` arrays from all processes in the group.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            reduction. If set to ``None`` the global group is used. Default:
+            ``None``.
+
+        Returns:
+          array: The sum of all ``x`` arrays.
+      )pbdoc");
+
+  m.def(
+      "all_gather",
+      &distributed::all_gather,
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      nb::sig(
+          "def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
+      R"pbdoc(
+        Gather arrays from all processes.
+
+        Gather the ``x`` arrays from all processes in the group and concatenate
+        them along the first axis. The arrays should all have the same shape.
+
+        Args:
+          x (array): Input array.
+          group (Group): The group of processes that will participate in the
+            gather. If set to ``None`` the global group is used. Default:
+            ``None``.
+
+        Returns:
+          array: The concatenation of all ``x`` arrays.
+      )pbdoc");
+}
diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp
index 30bf67fa5..a261c1f88 100644
--- a/python/src/mlx.cpp
+++ b/python/src/mlx.cpp
@@ -1,4 +1,4 @@
-// Conbright © 2023-2024 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <nanobind/nanobind.h>
 
@@ -18,6 +18,7 @@ void init_fft(nb::module_&);
 void init_linalg(nb::module_&);
 void init_constants(nb::module_&);
 void init_fast(nb::module_&);
+void init_distributed(nb::module_&);
 
 NB_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -37,6 +38,7 @@ NB_MODULE(core, m) {
   init_linalg(m);
   init_constants(m);
   init_fast(m);
+  init_distributed(m);
 
   m.attr("__version__") = TOSTRING(_VERSION_);
 }
diff --git a/python/tests/mpi_test_distributed.py b/python/tests/mpi_test_distributed.py
new file mode 100644
index 000000000..3e0504b76
--- /dev/null
+++ b/python/tests/mpi_test_distributed.py
@@ -0,0 +1,98 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+
+
+class TestDistributed(mlx_tests.MLXTestCase):
+    def test_groups(self):
+        world = mx.distributed.init()
+        self.assertEqual(world.size(), 8)
+        self.assertTrue(0 <= world.rank() < 8)
+
+        world2 = mx.distributed.init()
+        self.assertEqual(world.size(), world2.size())
+        self.assertEqual(world.rank(), world2.rank())
+
+        sub = world.split(world.rank() % 2)
+        self.assertEqual(sub.size(), 4)
+        self.assertEqual(sub.rank(), world.rank() // 2)
+
+        sub = world.split(world.rank() // 2)
+        self.assertEqual(sub.size(), 2)
+
+    def test_all_reduce(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.complex64,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_reduce_sum(x)
+            self.assertTrue(mx.all(y == world.size()))
+
+        sub = world.split(world.rank() % 2)
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_reduce_sum(x, group=sub)
+            self.assertTrue(mx.all(y == sub.size()))
+
+    def test_all_gather(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.complex64,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x)
+            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
+
+        sub = world.split(world.rank() % 2)
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x, group=sub)
+            self.assertEqual(y.shape, (sub.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
+
+    def test_mixed(self):
+        # Make the following groups:
+        # - world: 0 1 2 3 4 5 6 7
+        # - sub_1: 0 1 0 1 0 1 0 1
+        # - sub_2: 0 0 1 1 2 2 3 3
+        #
+        # The corresponding colors to make them are
+        # - world: N/A
+        # - sub_1: 0 0 1 1 2 2 3 3
+        # - sub_2: 0 1 0 1 0 1 0 1
+
+        world = mx.distributed.init()
+        sub_1 = world.split(world.rank() // 2)
+        sub_2 = world.split(world.rank() % 2)
+
+        x = mx.ones((1, 8)) * world.rank()
+        y = mx.distributed.all_reduce_sum(x, group=sub_1)
+        z = mx.distributed.all_gather(y, group=sub_2)
+        z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)
+
+        self.assertTrue(mx.all(z == z_target))
+
+
+if __name__ == "__main__":
+    unittest.main()
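
A minimal usage sketch of the Python API added above (not part of the diff; the script name and the 4-process launch are illustrative, and the calls mirror python/tests/mpi_test_distributed.py):

# example_distributed.py -- hypothetical script; assumes MLX was configured with MPI found
import mlx.core as mx

world = mx.distributed.init()                 # global group of all launched processes
x = mx.ones((4,)) * world.rank()              # rank-dependent data
total = mx.distributed.all_reduce_sum(x)      # every rank receives the same summed array
gathered = mx.distributed.all_gather(x)       # concatenated along axis 0, shape (4 * world.size(),)
print(world.rank(), "/", world.size(), total, gathered)

It can be launched the same way as the new CI step, for example mpirun -np 4 python example_distributed.py, exporting DYLD_LIBRARY_PATH to the OpenMPI libraries if libmpi.dylib is not on the default search path.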