mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 03:18:12 +08:00
Adds send/recv ops in distributed (#1366)
This commit is contained in:
committed by
GitHub
parent
1d94ac3f90
commit
cdb59faea6
@@ -4,6 +4,7 @@
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
#include <nanobind/stl/variant.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
@@ -121,4 +122,90 @@ void init_distributed(nb::module_& parent_module) {
|
||||
Returns:
|
||||
array: The concatenation of all ``x`` arrays.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"send",
|
||||
&distributed::send,
|
||||
"x"_a,
|
||||
"dst"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Send an array from the current process to the process that has rank
|
||||
``dst`` in the group.
|
||||
|
||||
Args:
|
||||
x (array): Input array.
|
||||
dst (int): Rank of the destination process in the group.
|
||||
group (Group): The group of processes that will participate in the
|
||||
sned. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: An empty array which when evaluated the send is performed.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"recv",
|
||||
&distributed::recv,
|
||||
"shape"_a,
|
||||
"dtype"_a,
|
||||
"src"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Recv an array with shape ``shape`` and dtype ``dtype`` from process
|
||||
with rank ``src``.
|
||||
|
||||
Args:
|
||||
shape (Tuple[int]): The shape of the array we are receiving.
|
||||
dtype (Dtype): The data type of the array we are receiving.
|
||||
src (int): Rank of the source process in the group.
|
||||
group (Group): The group of processes that will participate in the
|
||||
recv. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The array that was received from ``src``.
|
||||
)pbdoc");
|
||||
|
||||
m.def(
|
||||
"recv_like",
|
||||
&distributed::recv_like,
|
||||
"x"_a,
|
||||
"src"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
Recv an array with shape and type like ``x`` from process with rank
|
||||
``src``.
|
||||
|
||||
It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
|
||||
|
||||
Args:
|
||||
x (array): An array defining the shape and dtype of the array we are
|
||||
receiving.
|
||||
src (int): Rank of the source process in the group.
|
||||
group (Group): The group of processes that will participate in the
|
||||
recv. If set to ``None`` the global group is used. Default:
|
||||
``None``.
|
||||
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||
in which case the default stream of the default device is used.
|
||||
|
||||
Returns:
|
||||
array: The array that was received from ``src``.
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user