Comms (#1097)

* Start the communications branch using MPI
* Add ops and primitives
* Add python bindings for distributed
commit 50dfb664db (parent 0189ab6ab6)
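For orientation before the diff: the bindings added below expose a new mlx.core.distributed submodule with init, is_available, Group (rank/size/split), all_reduce_sum, and all_gather. A minimal usage sketch, assuming an MPI backend is available and the script is started by an MPI launcher (e.g. mpirun -np 4 python example.py; the launcher invocation and file name are assumptions, not part of this commit):

# Sketch of the Python API added by this commit. Assumes the process
# was launched under MPI so that init() can form the global group.
import mlx.core as mx

world = mx.distributed.init()          # global communication group
print(world.rank(), world.size())      # this process's rank and the group size

x = mx.ones((4,))
y = mx.distributed.all_reduce_sum(x)   # element-wise sum of x across processes
z = mx.distributed.all_gather(x)       # concatenation of x along the first axis
print(y, z.shape)                      # z.shape == (4 * world.size(),)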
python/src/CMakeLists.txt

@@ -6,6 +6,7 @@ nanobind_add_module(
   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/convert.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
python/src/distributed.cpp (new file, 107 lines)
@@ -0,0 +1,107 @@
+// Copyright © 2024 Apple Inc.
+
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/optional.h>
+#include <nanobind/stl/shared_ptr.h>
+
+#include "mlx/distributed/distributed.h"
+#include "mlx/distributed/ops.h"
+
+namespace nb = nanobind;
+using namespace nb::literals;
+using namespace mlx::core;
+
+void init_distributed(nb::module_& parent_module) {
+  auto m = parent_module.def_submodule(
+      "distributed", "mlx.core.distributed: Communication operations");
+
+  nb::class_<distributed::Group>(
+      m,
+      "Group",
+      R"pbcopy(
+      An :class:`mlx.core.distributed.Group` represents a group of independent mlx
+      processes that can communicate.
+      )pbcopy")
+      .def("rank", &distributed::Group::rank, "Get the rank of this process")
+      .def("size", &distributed::Group::size, "Get the size of the group")
+      .def(
+          "split",
+          &distributed::Group::split,
+          "color"_a,
+          "key"_a = -1,
+          nb::sig("def split(self, color: int, key: int = -1) -> Group"),
+          R"pbdoc(
+          Split the group to subgroups based on the provided color.
+
+          Processes that use the same color go to the same group. The ``key``
+          argument defines the rank in the new group. The smaller the key the
+          smaller the rank. If the key is negative then the rank in the
+          current group is used.
+
+          Args:
+            color (int): A value to group processes into subgroups.
+            key (int, optional): A key to optionally change the rank ordering
+              of the processes.
+          )pbdoc");
+
+  m.def(
+      "is_available",
+      &distributed::is_available,
+      R"pbdoc(
+      Check if a communication backend is available.
+      )pbdoc");
+
+  m.def(
+      "init",
+      &distributed::init,
+      R"pbdoc(
+      Initialize the communication backend and create the global communication group.
+      )pbdoc");
+
+  m.def(
+      "all_reduce_sum",
+      &distributed::all_reduce_sum,
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      nb::sig(
+          "def all_reduce_sum(x: array, *, group: Optional[Group] = None) -> array"),
+      R"pbdoc(
+      All reduce sum.
+
+      Sum the ``x`` arrays from all processes in the group.
+
+      Args:
+        x (array): Input array.
+        group (Group): The group of processes that will participate in the
+          reduction. If set to ``None`` the global group is used. Default:
+          ``None``.
+
+      Returns:
+        array: The sum of all ``x`` arrays.
+      )pbdoc");
+
+  m.def(
+      "all_gather",
+      &distributed::all_gather,
+      "x"_a,
+      nb::kw_only(),
+      "group"_a = nb::none(),
+      nb::sig(
+          "def all_gather(x: array, *, group: Optional[Group] = None) -> array"),
+      R"pbdoc(
+      Gather arrays from all processes.
+
+      Gather the ``x`` arrays from all processes in the group and concatenate
+      them along the first axis. The arrays should all have the same shape.
+
+      Args:
+        x (array): Input array.
+        group (Group): The group of processes that will participate in the
+          gather. If set to ``None`` the global group is used. Default:
+          ``None``.
+
+      Returns:
+        array: The concatenation of all ``x`` arrays.
+      )pbdoc");
+}
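The split semantics documented above are easiest to see with a small sketch (a hypothetical 8-process run; the color arithmetic mirrors the test file later in this diff):

# Assumes 8 processes. Processes passing the same color end up in the
# same subgroup; a negative key (the default) keeps the current rank order.
import mlx.core as mx

world = mx.distributed.init()                # ranks 0..7
evens_odds = world.split(world.rank() % 2)   # colors 0,1,0,1,... -> two subgroups of size 4
pairs = world.split(world.rank() // 2)       # colors 0,0,1,1,... -> four subgroups of size 2
print(evens_odds.size(), pairs.size())       # prints: 4 2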
python/src/mlx.cpp

@@ -1,4 +1,4 @@
-// Conbright © 2023-2024 Apple Inc.
+// Copyright © 2023-2024 Apple Inc.
 
 #include <nanobind/nanobind.h>
 
@@ -18,6 +18,7 @@ void init_fft(nb::module_&);
 void init_linalg(nb::module_&);
 void init_constants(nb::module_&);
 void init_fast(nb::module_&);
+void init_distributed(nb::module_&);
 
 NB_MODULE(core, m) {
   m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -37,6 +38,7 @@ NB_MODULE(core, m) {
   init_linalg(m);
   init_constants(m);
   init_fast(m);
+  init_distributed(m);
 
   m.attr("__version__") = TOSTRING(_VERSION_);
 }
python/tests/mpi_test_distributed.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+
+
+class TestDistributed(mlx_tests.MLXTestCase):
+    def test_groups(self):
+        world = mx.distributed.init()
+        self.assertEqual(world.size(), 8)
+        self.assertTrue(0 <= world.rank() < 8)
+
+        world2 = mx.distributed.init()
+        self.assertEqual(world.size(), world2.size())
+        self.assertEqual(world.rank(), world2.rank())
+
+        sub = world.split(world.rank() % 2)
+        self.assertEqual(sub.size(), 4)
+        self.assertEqual(sub.rank(), world.rank() // 2)
+
+        sub = world.split(world.rank() // 2)
+        self.assertEqual(sub.size(), 2)
+
+    def test_all_reduce(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.complex64,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_reduce_sum(x)
+            self.assertTrue(mx.all(y == world.size()))
+
+        sub = world.split(world.rank() % 2)
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_reduce_sum(x, group=sub)
+            self.assertTrue(mx.all(y == sub.size()))
+
+    def test_all_gather(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int16,
+            mx.uint16,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.complex64,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x)
+            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
+
+        sub = world.split(world.rank() % 2)
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x, group=sub)
+            self.assertEqual(y.shape, (sub.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
+
+    def test_mixed(self):
+        # Make the following groups:
+        # - world: 0 1 2 3 4 5 6 7
+        # - sub_1: 0 1 0 1 0 1 0 1
+        # - sub_2: 0 0 1 1 2 2 3 3
+        #
+        # The corresponding colors to make them are
+        # - world: N/A
+        # - sub_1: 0 0 1 1 2 2 3 3
+        # - sub_2: 0 1 0 1 0 1 0 1
+
+        world = mx.distributed.init()
+        sub_1 = world.split(world.rank() // 2)
+        sub_2 = world.split(world.rank() % 2)
+
+        x = mx.ones((1, 8)) * world.rank()
+        y = mx.distributed.all_reduce_sum(x, group=sub_1)
+        z = mx.distributed.all_gather(y, group=sub_2)
+        z_target = mx.arange(8).reshape(4, 2).sum(-1, keepdims=True)
+
+        self.assertTrue(mx.all(z == z_target))
+
+
+if __name__ == "__main__":
+    unittest.main()
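The assertions above require exactly eight processes (world.size() == 8), so the file is presumably run under an MPI launcher, e.g. something like mpirun -np 8 python python/tests/mpi_test_distributed.py (the exact invocation is an assumption, not part of this diff). As a worked check of test_mixed: all_reduce_sum over sub_1 sums rank values in pairs, giving 0+1=1, 2+3=5, 4+5=9, and 6+7=13 on the four pair groups, and all_gather over sub_2 stacks those four sums, matching z_target = arange(8).reshape(4, 2).sum(-1, keepdims=True) = [[1], [5], [9], [13]].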