Adds send/recv ops in distributed (#1366)

This commit is contained in:
Angelos Katharopoulos
2024-08-26 23:01:37 -07:00
committed by GitHub
parent 1d94ac3f90
commit cdb59faea6
13 changed files with 345 additions and 19 deletions

View File

@@ -4,6 +4,7 @@
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/variant.h>
#include <nanobind/stl/vector.h>
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
@@ -121,4 +122,90 @@ void init_distributed(nb::module_& parent_module) {
Returns:
array: The concatenation of all ``x`` arrays.
)pbdoc");
m.def(
"send",
&distributed::send,
"x"_a,
"dst"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Send an array from the current process to the process that has rank
``dst`` in the group.
Args:
x (array): Input array.
dst (int): Rank of the destination process in the group.
group (Group): The group of processes that will participate in the
sned. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: An empty array which when evaluated the send is performed.
)pbdoc");
m.def(
"recv",
&distributed::recv,
"shape"_a,
"dtype"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Recv an array with shape ``shape`` and dtype ``dtype`` from process
with rank ``src``.
Args:
shape (Tuple[int]): The shape of the array we are receiving.
dtype (Dtype): The data type of the array we are receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
)pbdoc");
m.def(
"recv_like",
&distributed::recv_like,
"x"_a,
"src"_a,
nb::kw_only(),
"group"_a = nb::none(),
"stream"_a = nb::none(),
nb::sig(
"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Recv an array with shape and type like ``x`` from process with rank
``src``.
It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
Args:
x (array): An array defining the shape and dtype of the array we are
receiving.
src (int): Rank of the source process in the group.
group (Group): The group of processes that will participate in the
recv. If set to ``None`` the global group is used. Default:
``None``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.
Returns:
array: The array that was received from ``src``.
)pbdoc");
}