mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Adds send/recv ops in distributed (#1366)
This commit is contained in:

committed by
GitHub

parent
1d94ac3f90
commit
cdb59faea6
@@ -4,6 +4,7 @@
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
@@ -121,4 +122,90 @@ void init_distributed(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The concatenation of all ``x`` arrays.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"send",
|
||||
&distributed::send,
|
||||
"x"_a,
|
||||
"dst"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Send an array from the current process to the process that has rank
|
||||
``dst`` in the group.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
dst (int): Rank of the destination process in the group.
|
||||
group (Group): The group of processes that will participate in the
|
||||
sned. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: An empty array which when evaluated the send is performed.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"recv",
|
||||
&distributed::recv,
|
||||
"shape"_a,
|
||||
"dtype"_a,
|
||||
"src"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Recv an array with shape ``shape`` and dtype ``dtype`` from process
|
||||
with rank ``src``.
|
||||
|
||||
Args:
|
||||
shape (Tuple[int]): The shape of the array we are receiving.
|
||||
dtype (Dtype): The data type of the array we are receiving.
|
||||
src (int): Rank of the source process in the group.
|
||||
group (Group): The group of processes that will participate in the
|
||||
recv. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The array that was received from ``src``.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"recv_like",
|
||||
&distributed::recv_like,
|
||||
"x"_a,
|
||||
"src"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Recv an array with shape and type like ``x`` from process with rank
|
||||
``src``.
|
||||
|
||||
It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
|
||||
|
||||
Args:
|
||||
x (array): An array defining the shape and dtype of the array we are
|
||||
receiving.
|
||||
src (int): Rank of the source process in the group.
|
||||
group (Group): The group of processes that will participate in the
|
||||
recv. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The array that was received from ``src``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertTrue(mx.all(z == z_target))
|
||||
|
||||
def test_send_recv(self):
|
||||
world = mx.distributed.init()
|
||||
pairs = world.split(world.rank() // 2)
|
||||
neighbor = (pairs.rank() + 1) % 2
|
||||
send = pairs.rank() == 0
|
||||
|
||||
x = mx.ones(10)
|
||||
for i in range(10):
|
||||
if send:
|
||||
mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
|
||||
else:
|
||||
x = mx.distributed.recv_like(x, neighbor, group=pairs)
|
||||
mx.eval(x)
|
||||
send = not send
|
||||
|
||||
self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user