Adds send/recv ops in distributed (#1366)

Angelos Katharopoulos
2024-08-26 23:01:37 -07:00
committed by GitHub
parent 1d94ac3f90
commit cdb59faea6
13 changed files with 345 additions and 19 deletions

View File

@@ -10,7 +10,7 @@
namespace mlx::core::distributed {
-void signal_and_wait(const array& in, const array& out, const Stream s) {
+void signal_and_wait(const array& in, const array& out, const Stream& s) {
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
@@ -81,4 +81,62 @@ void AllGather::eval_gpu(
signal_and_wait(in, out, stream());
}
void Send::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
// Schedule an async send on the comm stream
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
if (in.event().valid()) {
in.event().wait();
}
distributed::detail::send(group, in, dst);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a signal event for the input but not a wait since we don't need to
// wait on the output.
auto& s = stream();
auto& d = metal::device(s.device);
d.end_encoding(s.index);
auto command_buffer = d.get_command_buffer(s.index);
if (in.event().valid()) {
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(in.event().raw_event().get()),
in.event().value());
}
}
void Recv::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
// Schedule an async recv on the comm stream
auto task = [out = out, group = group(), src = src_]() mutable {
distributed::detail::recv(group, out, src);
out.event().signal();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
// Encode a wait event as there is no input for the recv to encode a signal.
auto& s = stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(out.event().raw_event().get()),
out.event().value());
}
} // namespace mlx::core::distributed
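
The synchronization set up above is a cross-stream handshake: the task queued on the communication stream waits on the input's event before reading its buffer and signals the output's event once the transfer is done, while the Metal command buffer encodes the matching signal (for Send) or wait (for Recv). A rough stand-alone analogue of that handshake, sketched with a condition variable instead of MTL::Event (illustrative only; this is not how the commit implements it):

// Illustrative sketch of the signal/wait handshake between a compute stream
// and a communication thread. Plain C++, no Metal or MPI involved.
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

// Minimal stand-in for an event with signal()/wait() semantics.
struct Event {
  std::mutex m;
  std::condition_variable cv;
  bool set = false;
  void signal() {
    { std::lock_guard<std::mutex> l(m); set = true; }
    cv.notify_all();
  }
  void wait() {
    std::unique_lock<std::mutex> l(m);
    cv.wait(l, [&] { return set; });
  }
};

int main() {
  Event input_ready;   // analogous to in.event(): signaled when `in` is computed
  Event output_ready;  // analogous to out.event(): signaled when the transfer is done

  std::thread comm([&] {
    input_ready.wait();           // like in.event().wait() before detail::send
    std::puts("transfer done");
    output_ready.signal();        // like out.event().signal()
  });

  std::puts("compute done");
  input_ready.signal();            // like command_buffer->encodeSignalEvent(...)
  output_ready.wait();             // like a later command buffer encoding a wait on out
  comm.join();
  return 0;
}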

View File

@@ -126,6 +126,8 @@ NO_GPU_MULTI(CustomKernel)
namespace distributed {
NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed
} // namespace mlx::core

View File

@@ -50,17 +50,4 @@ struct Group {
*/
Group init(bool strict = false);
-namespace detail {
-/* Return the communication stream. */
-Stream communication_stream();
-/* Perform an all reduce sum operation */
-void all_sum(Group group, const array& input, array& output);
-/* Perform an all reduce sum operation */
-void all_gather(Group group, const array& input, array& output);
-} // namespace detail
} // namespace mlx::core::distributed

View File

@@ -12,7 +12,13 @@ Stream communication_stream();
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);
-/* Perform an all reduce sum operation */
+/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);
/** Send an array to the dst rank */
void send(Group group, const array& input, int dst);
/** Recv an array from the src rank */
void recv(Group group, array& out, int src);
} // namespace mlx::core::distributed::detail

View File

@@ -48,6 +48,8 @@ struct MPIWrapper {
LOAD_SYMBOL(MPI_Comm_free, comm_free);
LOAD_SYMBOL(MPI_Allreduce, all_reduce);
LOAD_SYMBOL(MPI_Allgather, all_gather);
LOAD_SYMBOL(MPI_Send, send);
LOAD_SYMBOL(MPI_Recv, recv);
// Objects
LOAD_SYMBOL(ompi_mpi_comm_world, comm_world_);
@@ -142,6 +144,8 @@ struct MPIWrapper {
MPI_Comm);
int (*comm_split)(MPI_Comm, int, int, MPI_Comm*);
int (*comm_free)(MPI_Comm*);
int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm);
int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*);
// Objects
MPI_Comm comm_world_;
@@ -285,6 +289,29 @@ void all_gather(Group group, const array& input_, array& output) {
to_comm(group));
}
void send(Group group, const array& input_, int dst) {
array input = ensure_row_contiguous(input_);
mpi().send(
input.data<void>(),
input.size(),
mpi().datatype(input),
dst,
0,
to_comm(group));
}
void recv(Group group, array& out, int src) {
MPI_Status status;
mpi().recv(
out.data<void>(),
out.size(),
mpi().datatype(out),
src,
MPI_ANY_TAG,
to_comm(group),
&status);
}
} // namespace detail
} // namespace mlx::core::distributed
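
For reference, the two newly loaded symbols have the standard MPI point-to-point signatures, so the wrapper above forwards its arguments in the usual buffer/count/datatype/peer/tag/comm order. A minimal plain-MPI program using the same calls (standard MPI shown only for orientation; not part of this commit):

// Minimal two-rank ping using the same MPI calls the wrapper loads.
// Build with mpicc/mpic++ and run as `mpirun -n 2 ./a.out`.
#include <mpi.h>
#include <stdio.h>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  float buf[4] = {0, 1, 2, 3};
  if (rank == 0) {
    // Same argument order as the wrapper: buffer, count, datatype, dest, tag, comm.
    MPI_Send(buf, 4, MPI_FLOAT, 1, 0, MPI_COMM_WORLD);
  } else if (rank == 1) {
    MPI_Status status;
    MPI_Recv(buf, 4, MPI_FLOAT, 0, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
    printf("rank 1 received %f\n", buf[3]);
  }
  MPI_Finalize();
  return 0;
}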

View File

@@ -34,6 +34,8 @@ Stream communication_stream() {
void all_sum(Group group, const array& input, array& output) {}
void all_gather(Group group, const array& input, array& output) {}
void send(Group group, const array& input, int dst) {}
void recv(Group group, array& out, int src) {}
} // namespace detail

View File

@@ -1,5 +1,7 @@
// Copyright © 2024 Apple Inc.
#include <sstream>
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
@@ -57,4 +59,59 @@ array all_gather(
{x});
}
array send(
const array& x,
int dst,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
throw std::invalid_argument("Cannot send to a singleton group");
}
if (dst < 0 || dst >= group.size()) {
std::ostringstream msg;
msg << "Invalid destination=" << dst << " for a group of size "
<< group.size();
throw std::invalid_argument(msg.str());
}
return array(
{0}, int32, std::make_shared<Send>(to_stream(s), group, dst), {x});
}
array recv(
std::vector<int> shape,
Dtype dtype,
int src,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
throw std::invalid_argument("Cannot recv from a singleton group");
}
if (src < 0 || src >= group.size()) {
std::ostringstream msg;
msg << "Invalid source=" << src << " for a group of size " << group.size();
throw std::invalid_argument(msg.str());
}
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(to_stream(s), group, src),
std::vector<array>{});
}
array recv_like(
const array& x,
int src,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
return recv(x.shape(), x.dtype(), src, group_, s);
}
} // namespace mlx::core::distributed
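
A hypothetical end-to-end use of the new ops (not part of this commit; assumes a two-rank MPI launch and the C++ API declared in the header below):

// Hypothetical usage sketch: exchange arrays between two ranks with send/recv_like.
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto group = distributed::init();              // e.g. launched via `mpirun -n 2 ...`
  int rank = group.rank();
  int peer = 1 - rank;                           // assumes exactly two ranks
  auto x = full({4}, static_cast<float>(rank));  // each rank sends its own values

  if (rank == 0) {
    eval({distributed::send(x, peer, group)});       // send first,
    eval({distributed::recv_like(x, peer, group)});  // then receive the peer's array
  } else {
    eval({distributed::recv_like(x, peer, group)});  // receive first,
    eval({distributed::send(x, peer, group)});       // then reply (avoids a deadlock)
  }
  return 0;
}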

View File

@@ -13,9 +13,29 @@ array all_sum(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array all_gather(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array send(
const array& x,
int dst,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array recv(
std::vector<int> shape,
Dtype dtype,
int src,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array recv_like(
const array& x,
int src,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed

View File

@@ -35,7 +35,7 @@ std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
const std::vector<int>& axes) {
switch (reduce_type_) {
case Sum:
-return {{all_sum(inputs[0], group())}, axes};
+return {{all_sum(inputs[0], group(), stream())}, axes};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
@@ -47,7 +47,7 @@ std::vector<array> AllReduce::jvp(
const std::vector<int>& argnums) {
switch (reduce_type_) {
case Sum:
-return {all_sum(tangents[0], group())};
+return {all_sum(tangents[0], group(), stream())};
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
@@ -75,14 +75,14 @@ void AllGather::eval_cpu(
std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
-return {{all_gather(inputs[0], group())}, axes};
+return {{all_gather(inputs[0], group(), stream())}, axes};
}
std::vector<array> AllGather::jvp(
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums) {
-return {all_gather(tangents[0], group())};
+return {all_gather(tangents[0], group(), stream())};
}
std::vector<array> AllGather::vjp(
@@ -98,4 +98,29 @@ std::vector<array> AllGather::vjp(
return {slice(cotangents[0], starts, stops)};
}
void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
distributed::detail::send(group(), inputs[0], dst_);
}
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{send(inputs[0], dst_, group(), stream())}, axes};
}
void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_);
}
} // namespace mlx::core::distributed

View File

@@ -97,4 +97,39 @@ class AllGather : public DistPrimitive {
DEFINE_PRINT(AllGather);
};
class Send : public DistPrimitive {
public:
Send(Stream stream, Group group, int dst)
: DistPrimitive(stream, group), dst_(dst) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) override;
DEFINE_PRINT(Send);
private:
int dst_;
};
class Recv : public DistPrimitive {
public:
Recv(Stream stream, Group group, int src)
: DistPrimitive(stream, group), src_(src) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(Recv);
private:
int src_;
};
} // namespace mlx::core::distributed