* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
Angelos Katharopoulos 2024-05-23 17:04:02 -07:00 committed by GitHub
parent 0189ab6ab6
commit 50dfb664db
19 changed files with 913 additions and 1 deletion


@ -71,6 +71,7 @@ jobs:
name: Install dependencies
command: |
brew install python@3.8
brew install openmpi
python3.8 -m venv env
source env/bin/activate
pip install --upgrade pip
@ -96,6 +97,7 @@ jobs:
source env/bin/activate
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
- run:
name: Build example extension
command: |


@ -167,6 +167,11 @@ else()
set(MLX_BUILD_ACCELERATE OFF)
endif()
find_package(MPI)
if (MPI_FOUND)
target_include_directories(mlx PRIVATE ${MPI_INCLUDE_PATH})
endif()
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
target_include_directories(


@ -9,3 +9,4 @@ build_example(tutorial.cpp)
build_example(linear_regression.cpp)
build_example(logistic_regression.cpp)
build_example(metal_capture.cpp)
build_example(distributed.cpp)


@ -0,0 +1,22 @@
// Copyright © 2024 Apple Inc.
#include <iostream>
#include "mlx/mlx.h"
using namespace mlx::core;
int main() {
if (!distributed::is_available()) {
std::cout << "No communication backend found" << std::endl;
return 1;
}
auto global_group = distributed::init();
std::cout << global_group.rank() << " / " << global_group.size() << std::endl;
array x = ones({10});
array out = distributed::all_reduce_sum(x, global_group);
std::cout << out << std::endl;
}


@ -25,6 +25,7 @@ else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_cpu)
endif()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/distributed)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)
if (MLX_BUILD_ACCELERATE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/accelerate)


@ -0,0 +1,16 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
)
if (MPI_FOUND AND MLX_BUILD_CPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
else()
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/no_distributed.cpp
)
endif()


@ -0,0 +1,62 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <memory>
#include "mlx/array.h"
namespace mlx::core::distributed {
/* Check if a communication backend is available */
bool is_available();
/**
* A distributed::Group represents a group of independent mlx processes that
* can communicate. We must also be able to create sub-groups from a group in
* order to define more granular communication.
*/
struct Group {
Group(std::shared_ptr<void> group) : group_(group) {}
int rank();
int size();
/**
* Split the group according to the provided color. Namely, processes that use
* the same color will go to the same group.
*
* The key defines the rank of the processes in the new group. The smaller
* the key, the smaller the rank. If the provided key is negative, then the
* rank in the current group is used.
*/
Group split(int color, int key = -1);
const std::shared_ptr<void>& raw_group() {
return group_;
}
private:
std::shared_ptr<void> group_{nullptr};
};
/**
* Initialize the distributed backend and return the group containing all
* discoverable processes.
*/
Group init();
namespace detail {
/* Return the communication stream. */
Stream communication_stream();
/* Perform an all reduce sum operation */
void all_reduce_sum(Group group, const array& input, array& output);
/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);
} // namespace detail
} // namespace mlx::core::distributed
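
A minimal usage sketch of this Group API as exposed through the Python bindings added later in this commit; it assumes a launch under mpirun with 8 processes (the sub-group sizes shown depend on that assumption):

# Sketch: exercising init()/rank()/size()/split(); assumes an mpirun launch
# with 8 processes so the sub-group sizes below hold.
import mlx.core as mx

world = mx.distributed.init()            # group of all launched processes
print(world.rank(), "/", world.size())   # e.g. "3 / 8"

# Processes sharing a color end up in the same sub-group; with color
# rank // 2, the 8 ranks form 4 sub-groups of 2 neighbouring ranks each.
sub = world.split(world.rank() // 2)
print(sub.rank(), "/", sub.size())       # rank within the pair, size == 2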


@ -0,0 +1,5 @@
target_sources(
mlx
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/mpi.cpp
)

mlx/distributed/mpi/mpi.cpp (new file)

@ -0,0 +1,283 @@
// Copyright © 2024 Apple Inc.
#include <dlfcn.h>
#include <mpi.h>
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/distributed.h"
#include "mlx/scheduler.h"
#define LOAD_SYMBOL(symbol, variable) \
{ \
variable = (decltype(variable))dlsym(libmpi_handle_, #symbol); \
char* error = dlerror(); \
if (error != nullptr) { \
libmpi_handle_ = nullptr; \
return; \
} \
}
namespace mlx::core::distributed {
namespace {
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
struct MPIWrapper {
MPIWrapper() {
libmpi_handle_ = dlopen("libmpi.dylib", RTLD_NOW | RTLD_GLOBAL);
if (libmpi_handle_ == nullptr) {
return;
}
// API
LOAD_SYMBOL(MPI_Init, init);
LOAD_SYMBOL(MPI_Finalize, finalize);
LOAD_SYMBOL(MPI_Comm_rank, rank);
LOAD_SYMBOL(MPI_Comm_size, size);
LOAD_SYMBOL(MPI_Comm_split, comm_split);
LOAD_SYMBOL(MPI_Comm_free, comm_free);
LOAD_SYMBOL(MPI_Allreduce, all_reduce);
LOAD_SYMBOL(MPI_Allgather, all_gather);
// Objects
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
// Ops
LOAD_SYMBOL(ompi_mpi_op_sum, op_sum_);
// Datatypes
LOAD_SYMBOL(ompi_mpi_c_bool, mpi_bool_);
LOAD_SYMBOL(ompi_mpi_int8_t, mpi_int8_);
LOAD_SYMBOL(ompi_mpi_uint8_t, mpi_uint8_);
LOAD_SYMBOL(ompi_mpi_int16_t, mpi_int16_);
LOAD_SYMBOL(ompi_mpi_uint16_t, mpi_uint16_);
LOAD_SYMBOL(ompi_mpi_int32_t, mpi_int32_);
LOAD_SYMBOL(ompi_mpi_uint32_t, mpi_uint32_);
LOAD_SYMBOL(ompi_mpi_int64_t, mpi_int64_);
LOAD_SYMBOL(ompi_mpi_uint64_t, mpi_uint64_);
LOAD_SYMBOL(ompi_mpi_float, mpi_float_);
LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_);
}
bool is_available() {
return libmpi_handle_ != nullptr;
}
bool init_safe() {
if (!is_available()) {
return false;
}
return init(nullptr, nullptr) == MPI_SUCCESS;
}
void finalize_safe() {
if (is_available()) {
finalize();
}
}
MPI_Comm world() {
return comm_world_;
}
MPI_Datatype datatype(const array& arr) {
switch (arr.dtype()) {
case bool_:
return mpi_bool_;
case int8:
return mpi_int8_;
case uint8:
return mpi_uint8_;
case int16:
return mpi_int16_;
case uint16:
return mpi_uint16_;
case int32:
return mpi_int32_;
case uint32:
return mpi_uint32_;
case int64:
return mpi_int64_;
case uint64:
return mpi_uint64_;
case float32:
return mpi_float_;
case complex64:
return mpi_complex_;
case float16:
case bfloat16:
throw std::runtime_error("MPI doesn't support 16-bit floats");
}
}
MPI_Op op_sum() {
return op_sum_;
}
void* libmpi_handle_;
// API
int (*init)(int*, char***);
int (*finalize)();
int (*rank)(MPI_Comm, int*);
int (*size)(MPI_Comm, int*);
int (*all_reduce)(const void*, void*, int, MPI_Datatype, MPI_Op, MPI_Comm);
int (*all_gather)(
const void*,
int,
MPI_Datatype,
void*,
int,
MPI_Datatype,
MPI_Comm);
int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
int (*comm_free)(MPI_Comm*);
// Objects
MPI_Comm comm_world_;
// Ops
MPI_Op op_sum_;
// Datatypes
MPI_Datatype mpi_bool_;
MPI_Datatype mpi_int8_;
MPI_Datatype mpi_uint8_;
MPI_Datatype mpi_int16_;
MPI_Datatype mpi_uint16_;
MPI_Datatype mpi_int32_;
MPI_Datatype mpi_uint32_;
MPI_Datatype mpi_int64_;
MPI_Datatype mpi_uint64_;
MPI_Datatype mpi_float_;
MPI_Datatype mpi_complex_;
};
MPIWrapper& mpi() {
static MPIWrapper wrapper;
return wrapper;
}
struct MPIGroupImpl {
MPIGroupImpl(MPI_Comm comm, bool global)
: comm_(comm), global_(global), rank_(-1), size_(-1) {}
~MPIGroupImpl() {
if (global_) {
mpi().finalize_safe();
} else {
mpi().comm_free(&comm_);
}
}
MPI_Comm comm() {
return comm_;
}
int rank() {
if (rank_ < 0) {
mpi().rank(comm_, &rank_);
}
return rank_;
}
int size() {
if (size_ < 0) {
mpi().size(comm_, &size_);
}
return size_;
}
private:
MPI_Comm comm_;
bool global_;
int rank_;
int size_;
};
MPI_Comm to_comm(Group& group) {
return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
}
} // namespace
int Group::rank() {
return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
}
int Group::size() {
return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
}
Group Group::split(int color, int key) {
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
key = (key < 0) ? rank() : key;
MPI_Comm new_comm;
int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
if (result != MPI_SUCCESS) {
throw std::runtime_error("MPI could not split this group");
}
return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
}
bool is_available() {
return mpi().is_available();
}
Group init() {
static std::shared_ptr<MPIGroupImpl> global_group = nullptr;
if (global_group == nullptr) {
if (!mpi().init_safe()) {
throw std::runtime_error("Cannot initialize MPI");
}
global_group = std::make_shared<MPIGroupImpl>(mpi().world(), true);
}
return Group(global_group);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
void all_reduce_sum(Group group, const array& input_, array& output) {
array input = ensure_row_contiguous(input_);
mpi().all_reduce(
input.data<void>(),
output.data<void>(),
input.size(),
mpi().datatype(input),
mpi().op_sum(),
to_comm(group));
}
void all_gather(Group group, const array& input_, array& output) {
array input = ensure_row_contiguous(input_);
mpi().all_gather(
input.data<void>(),
input.size(),
mpi().datatype(input),
output.data<void>(),
input.size(),
mpi().datatype(output),
to_comm(group));
}
} // namespace detail
} // namespace mlx::core::distributed


@ -0,0 +1,39 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed {
int Group::rank() {
return 0;
}
int Group::size() {
return 1;
}
Group Group::split(int color, int key) {
throw std::runtime_error("Cannot split the distributed group further");
}
bool is_available() {
return false;
}
Group init() {
return Group(nullptr);
}
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
}
void all_reduce_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}
} // namespace detail
} // namespace mlx::core::distributed

mlx/distributed/ops.cpp (new file)

@ -0,0 +1,54 @@
// Copyright © 2024 Apple Inc.
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
namespace mlx::core::distributed {
namespace {
Group to_group(std::optional<Group> group) {
if (group.has_value()) {
return group.value();
} else {
return distributed::init();
}
}
} // namespace
array all_reduce_sum(const array& x, std::optional<Group> group_) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(group, AllReduce::Sum),
{x});
}
array all_gather(const array& x, std::optional<Group> group_) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
auto result_shape = x.shape();
if (result_shape.size() == 0) {
result_shape.push_back(group.size());
} else {
result_shape[0] *= group.size();
}
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(group),
{x});
}
} // namespace mlx::core::distributed

mlx/distributed/ops.h (new file)

@ -0,0 +1,14 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <optional>
#include "mlx/distributed/distributed.h"
namespace mlx::core::distributed {
array all_reduce_sum(const array& x, std::optional<Group> group = std::nullopt);
array all_gather(const array& x, std::optional<Group> group = std::nullopt);
} // namespace mlx::core::distributed


@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"
namespace mlx::core::distributed {
void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
switch (reduce_type_) {
case Sum:
distributed::detail::all_reduce_sum(group(), inputs[0], outputs[0]);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
switch (reduce_type_) {
case Sum:
return {{all_reduce_sum(inputs[0], group())}, axes};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
std::vector<array> AllReduce::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
switch (reduce_type_) {
case Sum:
return {all_reduce_sum(tangents[0], group())};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
std::vector<array> AllReduce::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
return cotangents;
}
void AllGather::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::all_gather(group(), inputs[0], outputs[0]);
}
std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{all_gather(inputs[0], group())}, axes};
}
std::vector<array> AllGather::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
return {all_gather(tangents[0], group())};
}
std::vector<array> AllGather::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
auto g = group();
std::vector<int> starts(primals[0].ndim(), 0);
auto stops = primals[0].shape();
starts[0] = g.rank() * stops[0];
stops[0] += starts[0];
return {slice(cotangents[0], starts, stops)};
}
} // namespace mlx::core::distributed
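
The AllGather VJP above hands each rank the slice of the gathered cotangent that corresponds to its own contribution. A small standalone sketch of that index arithmetic, using a hypothetical rank and primal shape for illustration:

# Mirrors AllGather::vjp: rank r keeps rows [r * n, (r + 1) * n) of the
# gathered cotangent, where n is the size of the primal's first axis.
def all_gather_vjp_slice(rank, primal_shape):
    starts = [0] * len(primal_shape)
    stops = list(primal_shape)
    starts[0] = rank * stops[0]
    stops[0] += starts[0]
    return starts, stops

# A primal of shape (2, 4) on rank 3 maps to rows 6..8 of the gathered array.
assert all_gather_vjp_slice(3, (2, 4)) == ([6, 0], [8, 4])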


@ -0,0 +1,100 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "mlx/distributed/distributed.h"
#include "mlx/primitives.h"
namespace mlx::core::distributed {
class DistPrimitive : public Primitive {
public:
DistPrimitive(Group group)
: Primitive(detail::communication_stream()), group_(group) {}
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error(
"Communication primitives cannot be run on the GPU");
}
const Group& group() const {
return group_;
}
private:
Group group_;
};
class AllReduce : public DistPrimitive {
public:
enum ReduceType { And, Or, Sum, Prod, Min, Max };
AllReduce(Group group, ReduceType reduce_type)
: DistPrimitive(group), reduce_type_(reduce_type) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
void print(std::ostream& os) override {
switch (reduce_type_) {
case And:
os << "And";
break;
case Or:
os << "Or";
break;
case Sum:
os << "Sum";
break;
case Prod:
os << "Prod";
break;
case Min:
os << "Min";
break;
case Max:
os << "Max";
break;
}
os << " AllReduce";
}
private:
ReduceType reduce_type_;
};
class AllGather : public DistPrimitive {
public:
AllGather(Group group) : DistPrimitive(group) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
std::vector<array> jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) override;
std::vector<array> vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs) override;
DEFINE_PRINT(AllGather);
};
} // namespace mlx::core::distributed


@ -6,6 +6,8 @@
#include "mlx/backend/metal/metal.h" #include "mlx/backend/metal/metal.h"
#include "mlx/compile.h" #include "mlx/compile.h"
#include "mlx/device.h" #include "mlx/device.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "mlx/fast.h" #include "mlx/fast.h"
#include "mlx/fft.h" #include "mlx/fft.h"
#include "mlx/io.h" #include "mlx/io.h"


@ -6,6 +6,7 @@ nanobind_add_module(
${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp

python/src/distributed.cpp (new file)

@ -0,0 +1,107 @@
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlx::core;
void init_distributed(nb::module_& parent_module) {
auto m = parent_module.def_submodule(
"distributed", "mlx.core.distributed: Communication operations");
nb::class_<distributed::Group>(
m,
"Group",
R"pbcopy(
An :class:`mlx.core.distributed.Group` represents a group of independent mlx
processes that can communicate.
)pbcopy")
.def("rank", &distributed::Group::rank, "Get the rank of this process")
.def("size", &distributed::Group::size, "Get the size of the group")
.def(
"split",
&distributed::Group::split,
"color"_a,
"key"_a = -1,
nb::sig("def split(self, color: int, key: int = -1) -> Group"),
R"pbdoc(
Split the group into subgroups based on the provided color.
Processes that use the same color go to the same group. The ``key``
argument defines the rank in the new group. The smaller the key, the
smaller the rank. If the key is negative, then the rank in the
current group is used.
Args:
color (int): A value to group processes into subgroups.
key (int, optional): A key to optionally change the rank ordering
of the processes.
)pbdoc");
m.def(
"is_available",
&distributed::is_available,
R"pbdoc(
Check if a communication backend is available.
)pbdoc");
m.def(
"init",
&distributed::init,
R"pbdoc(
Initialize the communication backend and create the global communication group.
)pbdoc");
m.def(
"all_reduce_sum",
&distributed::all_reduce_sum,
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
nb::sig(
"def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
R"pbdoc(
All reduce sum.
Sum the ``x`` arrays from all processes in the group.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
reduction. If set to ``None`` the global group is used. Default:
``None``.
Returns:
array: The sum of all ``x`` arrays.
)pbdoc");
m.def(
"all_gather",
&distributed::all_gather,
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
nb::sig(
"def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
R"pbdoc(
Gather arrays from all processes.
Gather the ``x`` arrays from all processes in the group and concatenate
them along the first axis. The arrays should all have the same shape.
Args:
x (array): Input array.
group (Group): The group of processes that will participate in the
gather. If set to ``None`` the global group is used. Default:
``None``.
Returns:
array: The concatenation of all ``x`` arrays.
)pbdoc");
}
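
A short usage sketch of the collectives bound above, assuming the script is launched under mpirun as in the CI step; with a single process both ops simply return x unchanged:

# Sketch: the two collective ops exposed by this module.
import mlx.core as mx

world = mx.distributed.init()

x = mx.ones((2, 4))
total = mx.distributed.all_reduce_sum(x)   # every entry equals world.size()
gathered = mx.distributed.all_gather(x)    # shape (2 * world.size(), 4)

# Restrict the reduction to a sub-group (here: even vs. odd ranks).
sub = world.split(world.rank() % 2)
partial = mx.distributed.all_reduce_sum(x, group=sub)
print(total.sum(), gathered.shape, partial.sum())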


@ -1,4 +1,4 @@
// Copyright © 2023-2024 Apple Inc.
#include <nanobind/nanobind.h>
@ -18,6 +18,7 @@ void init_fft(nb::module_&);
void init_linalg(nb::module_&);
void init_constants(nb::module_&);
void init_fast(nb::module_&);
void init_distributed(nb::module_&);
NB_MODULE(core, m) {
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@ -37,6 +38,7 @@ NB_MODULE(core, m) {
init_linalg(m);
init_constants(m);
init_fast(m);
init_distributed(m);
m.attr("__version__") = TOSTRING(_VERSION_); m.attr("__version__") = TOSTRING(_VERSION_);
} }


@ -0,0 +1,98 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestDistributed(mlx_tests.MLXTestCase):
def test_groups(self):
world = mx.distributed.init()
self.assertEqual(world.size(), 8)
self.assertTrue(0 <= world.rank() < 8)
world2 = mx.distributed.init()
self.assertEqual(world.size(), world2.size())
self.assertEqual(world.rank(), world2.rank())
sub = world.split(world.rank() % 2)
self.assertEqual(sub.size(), 4)
self.assertEqual(sub.rank(), world.rank() // 2)
sub = world.split(world.rank() // 2)
self.assertEqual(sub.size(), 2)
def test_all_reduce(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.complex64,
]
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_reduce_sum(x)
self.assertTrue(mx.all(y == world.size()))
sub = world.split(world.rank() % 2)
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_reduce_sum(x, group=sub)
self.assertTrue(mx.all(y == sub.size()))
def test_all_gather(self):
world = mx.distributed.init()
dtypes = [
mx.int8,
mx.uint8,
mx.int16,
mx.uint16,
mx.int32,
mx.uint32,
mx.float32,
mx.complex64,
]
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_gather(x)
self.assertEqual(y.shape, (world.size() * 2, 2, 4))
self.assertTrue(mx.all(y == 1))
sub = world.split(world.rank() % 2)
for dt in dtypes:
x = mx.ones((2, 2, 4), dtype=dt)
y = mx.distributed.all_gather(x, group=sub)
self.assertEqual(y.shape, (sub.size() * 2, 2, 4))
self.assertTrue(mx.all(y == 1))
def test_mixed(self):
# Make the following groups:
# - world: 0 1 2 3 4 5 6 7
# - sub_1: 0 1 0 1 0 1 0 1
# - sub_2: 0 0 1 1 2 2 3 3
#
# The corresponding colors to make them are
# - world: N/A
# - sub_1: 0 0 1 1 2 2 3 3
# - sub_2: 0 1 0 1 0 1 0 1
world = mx.distributed.init()
sub_1 = world.split(world.rank() // 2)
sub_2 = world.split(world.rank() % 2)
x = mx.ones((1, 8)) * world.rank()
y = mx.distributed.all_reduce_sum(x, group=sub_1)
z = mx.distributed.all_gather(y, group=sub_2)
z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)
self.assertTrue(mx.all(z == z_target))
if __name__ == "__main__":
unittest.main()