Adds send/recv ops in distributed (#1366)
parent 1d94ac3f90 · commit cdb59faea6

@@ -17,3 +17,6 @@ made available.
    init
    all_sum
    all_gather
+   send
+   recv
+   recv_like

@@ -10,7 +10,7 @@
 
 namespace mlx::core::distributed {
 
-void signal_and_wait(const array& in, const array& out, const Stream s) {
+void signal_and_wait(const array& in, const array& out, const Stream& s) {
   auto& d = metal::device(s.device);
   d.end_encoding(s.index);
   auto command_buffer = d.get_command_buffer(s.index);

@@ -81,4 +81,62 @@ void AllGather::eval_gpu(
   signal_and_wait(in, out, stream());
 }
 
+void Send::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 1);
+  assert(outputs.size() == 1);
+
+  auto& in = inputs[0];
+  auto& out = outputs[0];
+
+  // Schedule an async send on the comm stream
+  auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
+    if (in.event().valid()) {
+      in.event().wait();
+    }
+    distributed::detail::send(group, in, dst);
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+
+  // Encode a signal event for the input but not a wait since we don't need to
+  // wait on the output.
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  d.end_encoding(s.index);
+  auto command_buffer = d.get_command_buffer(s.index);
+  if (in.event().valid()) {
+    command_buffer->encodeSignalEvent(
+        static_cast<MTL::Event*>(in.event().raw_event().get()),
+        in.event().value());
+  }
+}
+
+void Recv::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  assert(inputs.size() == 0);
+  assert(outputs.size() == 1);
+
+  auto& out = outputs[0];
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  // Schedule an async recv on the comm stream
+  auto task = [out = out, group = group(), src = src_]() mutable {
+    distributed::detail::recv(group, out, src);
+    out.event().signal();
+  };
+  scheduler::enqueue(detail::communication_stream(), std::move(task));
+
+  // Encode a wait event as there is no input for the recv to encode a signal.
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+  auto command_buffer = d.get_command_buffer(s.index);
+  command_buffer->encodeWait(
+      static_cast<MTL::Event*>(out.event().raw_event().get()),
+      out.event().value());
+}
+
 } // namespace mlx::core::distributed

@@ -126,6 +126,8 @@ NO_GPU_MULTI(CustomKernel)
 namespace distributed {
 NO_GPU_MULTI(AllReduce)
 NO_GPU_MULTI(AllGather)
+NO_GPU_MULTI(Send)
+NO_GPU_MULTI(Recv)
 } // namespace distributed
 
 } // namespace mlx::core

@@ -50,17 +50,4 @@ struct Group {
  */
 Group init(bool strict = false);
 
-namespace detail {
-
-/* Return the communication stream. */
-Stream communication_stream();
-
-/* Perform an all reduce sum operation */
-void all_sum(Group group, const array& input, array& output);
-
-/* Perform an all reduce sum operation */
-void all_gather(Group group, const array& input, array& output);
-
-} // namespace detail
-
 } // namespace mlx::core::distributed

@@ -12,7 +12,13 @@ Stream communication_stream();
 /* Perform an all reduce sum operation */
 void all_sum(Group group, const array& input, array& output);
 
-/* Perform an all reduce sum operation */
+/* Perform an all gather operation */
 void all_gather(Group group, const array& input, array& output);
 
+/** Send an array to the dst rank */
+void send(Group group, const array& input, int dst);
+
+/** Recv an array from the src rank */
+void recv(Group group, array& out, int src);
+
 } // namespace mlx::core::distributed::detail

@@ -48,6 +48,8 @@ struct MPIWrapper {
     LOAD_SYMBOL(MPI_Comm_free, comm_free);
     LOAD_SYMBOL(MPI_Allreduce, all_reduce);
     LOAD_SYMBOL(MPI_Allgather, all_gather);
+    LOAD_SYMBOL(MPI_Send, send);
+    LOAD_SYMBOL(MPI_Recv, recv);
 
     // Objects
     LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);

@@ -142,6 +144,8 @@ struct MPIWrapper {
       MPI_Comm);
   int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
   int (*comm_free)(MPI_Comm*);
+  int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
+  int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
 
   // Objects
   MPI_Comm comm_world_;

@@ -285,6 +289,29 @@ void all_gather(Group group, const array& input_, array& output) {
       to_comm(group));
 }
 
+void send(Group group, const array& input_, int dst) {
+  array input = ensure_row_contiguous(input_);
+  mpi().send(
+      input.data<void>(),
+      input.size(),
+      mpi().datatype(input),
+      dst,
+      0,
+      to_comm(group));
+}
+
+void recv(Group group, array& out, int src) {
+  MPI_Status status;
+  mpi().recv(
+      out.data<void>(),
+      out.size(),
+      mpi().datatype(out),
+      src,
+      MPI_ANY_TAG,
+      to_comm(group),
+      &status);
+}
+
 } // namespace detail
 
 } // namespace mlx::core::distributed

@@ -34,6 +34,8 @@ Stream communication_stream() {
 
 void all_sum(Group group, const array& input, array& output) {}
 void all_gather(Group group, const array& input, array& output) {}
+void send(Group group, const array& input, int dst) {}
+void recv(Group group, array& out, int src) {}
 
 } // namespace detail
 

@@ -1,5 +1,7 @@
 // Copyright © 2024 Apple Inc.
 
+#include <sstream>
+
 #include "mlx/distributed/ops.h"
 #include "mlx/distributed/primitives.h"
 

@@ -57,4 +59,59 @@ array all_gather(
       {x});
 }
 
+array send(
+    const array& x,
+    int dst,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    throw std::invalid_argument("Cannot send to a singleton group");
+  }
+
+  if (dst < 0 || dst >= group.size()) {
+    std::ostringstream msg;
+    msg << "Invalid destination=" << dst << " for a group of size "
+        << group.size();
+    throw std::invalid_argument(msg.str());
+  }
+
+  return array(
+      {0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
+}
+
+array recv(
+    std::vector<int> shape,
+    Dtype dtype,
+    int src,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  auto group = to_group(group_);
+
+  if (group.size() == 1) {
+    throw std::invalid_argument("Cannot recv from a singleton group");
+  }
+
+  if (src < 0 || src >= group.size()) {
+    std::ostringstream msg;
+    msg << "Invalid source=" << src << " for a group of size " << group.size();
+    throw std::invalid_argument(msg.str());
+  }
+
+  return array(
+      std::move(shape),
+      std::move(dtype),
+      std::make_shared<Recv>(to_stream(s), group, src),
+      std::vector<array>{});
+}
+
+array recv_like(
+    const array& x,
+    int src,
+    std::optional<Group> group_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  return recv(x.shape(), x.dtype(), src, group_, s);
+}
+
 } // namespace mlx::core::distributed

@@ -13,9 +13,29 @@ array all_sum(
     const array& x,
     std::optional<Group> group = std::nullopt,
     StreamOrDevice s = {});
 
 array all_gather(
     const array& x,
     std::optional<Group> group = std::nullopt,
     StreamOrDevice S = {});
 
+array send(
+    const array& x,
+    int dst,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
+array recv(
+    std::vector<int> shape,
+    Dtype dtype,
+    int src,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
+array recv_like(
+    const array& x,
+    int src,
+    std::optional<Group> group = std::nullopt,
+    StreamOrDevice s = {});
+
 } // namespace mlx::core::distributed

@@ -35,7 +35,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
     const std::vector<int>& axes) {
   switch (reduce_type_) {
     case Sum:
-      return {{all_sum(inputs[0], group())}, axes};
+      return {{all_sum(inputs[0], group(), stream())}, axes};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }

@@ -47,7 +47,7 @@ std::vector<array> AllReduce::jvp(
     const std::vector<int>& argnums) {
   switch (reduce_type_) {
     case Sum:
-      return {all_sum(tangents[0], group())};
+      return {all_sum(tangents[0], group(), stream())};
     default:
       throw std::runtime_error("Only all reduce sum is supported for now");
   }

@@ -75,14 +75,14 @@ void AllGather::eval_cpu(
 std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
     const std::vector<array>& inputs,
     const std::vector<int>& axes) {
-  return {{all_gather(inputs[0], group())}, axes};
+  return {{all_gather(inputs[0], group(), stream())}, axes};
 }
 
 std::vector<array> AllGather::jvp(
     const std::vector<array>& primals,
     const std::vector<array>& tangents,
     const std::vector<int>& argnums) {
-  return {all_gather(tangents[0], group())};
+  return {all_gather(tangents[0], group(), stream())};
 }
 
 std::vector<array> AllGather::vjp(

|
|||||||
return {slice(cotangents[0], starts, stops)};
|
return {slice(cotangents[0], starts, stops)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Send::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
distributed::detail::send(group(), inputs[0], dst_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
const std::vector<int>& axes) {
|
||||||
|
return {{send(inputs[0], dst_, group(), stream())}, axes};
|
||||||
|
}
|
||||||
|
|
||||||
|
void Recv::eval_cpu(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
std::vector<array>& outputs) {
|
||||||
|
assert(inputs.size() == 0);
|
||||||
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
|
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
|
||||||
|
distributed::detail::recv(group(), outputs[0], src_);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::distributed
|
} // namespace mlx::core::distributed
|
||||||
|
@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
   DEFINE_PRINT(AllGather);
 };
 
+class Send : public DistPrimitive {
+ public:
+  Send(Stream stream, Group group, int dst)
+      : DistPrimitive(stream, group), dst_(dst) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  std::pair<std::vector<array>, std::vector<int>> vmap(
+      const std::vector<array>& inputs,
+      const std::vector<int>& axes) override;
+
+  DEFINE_PRINT(Send);
+
+ private:
+  int dst_;
+};
+
+class Recv : public DistPrimitive {
+ public:
+  Recv(Stream stream, Group group, int src)
+      : DistPrimitive(stream, group), src_(src) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(Recv);
+
+ private:
+  int src_;
+};
+
 } // namespace mlx::core::distributed

@@ -4,6 +4,7 @@
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/shared_ptr.h>
 #include <nanobind/stl/variant.h>
+#include <nanobind/stl/vector.h>
 
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"

|
|||||||
Returns:
|
Returns:
|
||||||
array: The concatenation of all ``x`` arrays.
|
array: The concatenation of all ``x`` arrays.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"send",
|
||||||
|
&distributed::send,
|
||||||
|
"x"_a,
|
||||||
|
"dst"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"group"_a = nb::none(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def send(x: array, dst: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Send an array from the current process to the process that has rank
|
||||||
|
``dst`` in the group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (array): Input array.
|
||||||
|
dst (int): Rank of the destination process in the group.
|
||||||
|
group (Group): The group of processes that will participate in the
|
||||||
|
sned. If set to ``None`` the global group is used. Default:
|
||||||
|
``None``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: An empty array which when evaluated the send is performed.
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"recv",
|
||||||
|
&distributed::recv,
|
||||||
|
"shape"_a,
|
||||||
|
"dtype"_a,
|
||||||
|
"src"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"group"_a = nb::none(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Recv an array with shape ``shape`` and dtype ``dtype`` from process
|
||||||
|
with rank ``src``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape (Tuple[int]): The shape of the array we are receiving.
|
||||||
|
dtype (Dtype): The data type of the array we are receiving.
|
||||||
|
src (int): Rank of the source process in the group.
|
||||||
|
group (Group): The group of processes that will participate in the
|
||||||
|
recv. If set to ``None`` the global group is used. Default:
|
||||||
|
``None``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The array that was received from ``src``.
|
||||||
|
)pbdoc");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"recv_like",
|
||||||
|
&distributed::recv_like,
|
||||||
|
"x"_a,
|
||||||
|
"src"_a,
|
||||||
|
nb::kw_only(),
|
||||||
|
"group"_a = nb::none(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def recv_like(x: array, src: int, *, group: Optional[Group] = None, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Recv an array with shape and type like ``x`` from process with rank
|
||||||
|
``src``.
|
||||||
|
|
||||||
|
It is equivalent to calling ``mx.distributed.recv(x.shape, x.dtype, src)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (array): An array defining the shape and dtype of the array we are
|
||||||
|
receiving.
|
||||||
|
src (int): Rank of the source process in the group.
|
||||||
|
group (Group): The group of processes that will participate in the
|
||||||
|
recv. If set to ``None`` the global group is used. Default:
|
||||||
|
``None``.
|
||||||
|
stream (Stream, optional): Stream or device. Defaults to ``None``
|
||||||
|
in which case the default stream of the default device is used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The array that was received from ``src``.
|
||||||
|
)pbdoc");
|
||||||
}
|
}
|
||||||
|
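A minimal usage sketch of the Python API added above, distilled from the docstrings and the test that follows. The script name and the launcher invocation (e.g. running it under an MPI launcher with two processes, such as mpirun -np 2 python example.py) are illustrative assumptions, not part of this change.

# example.py (hypothetical): one exchange between two ranks using the new ops.
import mlx.core as mx

world = mx.distributed.init()
assert world.size() == 2, "run this sketch with exactly two processes"

x = mx.ones(10)
if world.rank() == 0:
    # send() returns an empty array; evaluating it is what performs the send.
    mx.eval(mx.distributed.send(2 * x, 1))
else:
    # recv_like() takes the shape and dtype from x;
    # mx.distributed.recv(x.shape, x.dtype, 0) would be equivalent.
    y = mx.distributed.recv_like(x, 0)
    mx.eval(y)  # y is now an array of 2.0s received from rank 0
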
@@ -93,6 +93,23 @@ class TestDistributed(mlx_tests.MLXTestCase):
 
         self.assertTrue(mx.all(z == z_target))
 
+    def test_send_recv(self):
+        world = mx.distributed.init()
+        pairs = world.split(world.rank() // 2)
+        neighbor = (pairs.rank() + 1) % 2
+        send = pairs.rank() == 0
+
+        x = mx.ones(10)
+        for i in range(10):
+            if send:
+                mx.eval(mx.distributed.send(2 * x, neighbor, group=pairs))
+            else:
+                x = mx.distributed.recv_like(x, neighbor, group=pairs)
+                mx.eval(x)
+            send = not send
+
+        self.assertTrue(mx.all(x == (1024 if pairs.rank() == 0 else 512)))
+
 
 if __name__ == "__main__":
     unittest.main()
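The constants in the final assertion follow from the ping-pong doubling: the two ranks alternate sender and receiver, each receive picks up twice the sender's current value, and after ten rounds rank 0 ends on 2^10 = 1024 while rank 1 ends on 2^9 = 512. A small stand-alone sketch (plain Python, no MPI required) that reproduces those values:

# Simulate the ping-pong in test_send_recv to check the expected constants.
vals = [1.0, 1.0]   # the value of x held by rank 0 and rank 1
sender = 0          # rank 0 sends first
for _ in range(10):
    receiver = 1 - sender
    vals[receiver] = 2 * vals[sender]   # receiver's x becomes 2 * sender's x
    sender = receiver                   # roles swap each iteration
print(vals)  # [1024.0, 512.0], matching the assertion for ranks 0 and 1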