Adds send/recv ops in distributed (#1366)

This commit is contained in:
Angelos Katharopoulos
2024-08-26 23:01:37 -07:00
committed by GitHub
parent 1d94ac3f90
commit cdb59faea6
13 changed files with 345 additions and 19 deletions

View File

@@ -4,6 +4,7 @@
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
@@ -121,4 +122,90 @@ void init_distributed(nb::module_& parent_module) {
Returns:
array: The concatenation of all ``x`` arrays.
)pbdoc");
m.def(
"send",
&distributed::send,
"x"_a,
"dst"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Send an array from the current process to the process that has rank
``dst`` in the group.
Args:
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
sned. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: An empty array which when evaluated the send is performed.
)pbdoc");
m.def(
"recv",
&distributed::recv,
"shape"_a,
"dtype"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Recv an array with shape ``shape`` and dtype ``dtype`` from process
with rank ``src``.
Args:
shape (Tuple[int]): The shape of the array we are receiving.
dtype (Dtype): The data type of the array we are receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
)pbdoc");
m.def(
"recv_like",
&distributed::recv_like,
"x"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Recv an array with shape and type like ``x`` from process with rank
``src``.
It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
Args:
x (array): An array defining the shape and dtype of the array we are
receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
)pbdoc");
}

View File

@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):
self.assertTrue(mx.all(z == z_target))
def test_send_recv(self):
world = mx.distributed.init()
pairs = world.split(world.rank() // 2)
neighbor = (pairs.rank() + 1) % 2
send = pairs.rank() == 0
x = mx.ones(10)
for i in range(10):
if send:
mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
else:
x = mx.distributed.recv_like(x, neighbor, group=pairs)
mx.eval(x)
send = not send
self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
if __name__ == "__main__":
unittest.main()