Nccl reduce scatter, all gather (#2727)

* Added reduce scatter and all gather for nccl

* fix unused import, delete unused file

* small fix

* deleted useless condition

* fixed comments

* fix bug in eval_gpu, renamed to sum_scatter, fix docs

* final fix docs

* remove and

* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output

* fixes set output

* typo

* fix typo

* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Author: Anastasiia Filippova
Date: 2025-11-05 17:21:11 +01:00
Committed by: GitHub
Parent: 761f901a41
Commit: 27778156dc
19 changed files with 351 additions and 311 deletions

View File

@@ -95,4 +95,9 @@ void Recv::eval_cpu(
distributed::detail::recv(group(), outputs[0], src_, stream());
}
void ReduceScatter::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed

View File

@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
"Only all reduce sum, max, and min are supported.");
}
}
void AllGather::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
distributed::detail::all_gather(group(), input, outputs[0], s);
}
void ReduceScatter::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
auto ensure_contiguous = [&s, &encoder](const array& x) {
if (x.flags().row_contiguous) {
return x;
} else {
array x_copy = contiguous_copy_gpu(x, s);
encoder.add_temporary(x_copy);
return x_copy;
}
};
auto input = ensure_contiguous(inputs[0]);
outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));
encoder.set_input_array(input);
encoder.set_output_array(outputs[0]);
auto capture = encoder.capture_context();
switch (reduce_type_) {
case Sum:
distributed::detail::sum_scatter(group(), input, outputs[0], s);
break;
default:
throw std::runtime_error("Only sum scatter is supported. ");
}
}
} // namespace mlx::core::distributed
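
Both eval_gpu overrides above duplicate the same lambda to guarantee a row-contiguous input before handing the buffer to NCCL. A minimal sketch of how that could be hoisted into a file-local helper (the name ensure_row_contiguous is hypothetical and not part of this commit; it assumes the encoder type returned by cu::get_command_encoder is cu::CommandEncoder and only reuses calls that already appear above):

static array ensure_row_contiguous(
    const array& x,
    cu::CommandEncoder& encoder,
    const Stream& s) {
  // Return the input untouched when it is already row contiguous.
  if (x.flags().row_contiguous) {
    return x;
  }
  // Otherwise make a contiguous copy on the stream and register it as a
  // temporary so it stays alive until the collective has been enqueued.
  array x_copy = contiguous_copy_gpu(x, s);
  encoder.add_temporary(x_copy);
  return x_copy;
}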

View File

@@ -40,7 +40,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
namespace distributed {
-NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed

View File

@@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
}
void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error(
"[ReduceScatter::eval_gpu] has no GPU implementation.");
}
} // namespace mlx::core::distributed

View File

@@ -138,6 +138,7 @@ NO_CPU_MULTI(AllReduce)
NO_CPU_MULTI(AllGather)
NO_CPU_MULTI(Send)
NO_CPU_MULTI(Recv)
NO_CPU_MULTI(ReduceScatter)
} // namespace distributed
} // namespace mlx::core

View File

@@ -164,6 +164,7 @@ NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
NO_GPU_MULTI(ReduceScatter)
} // namespace distributed
} // namespace mlx::core

View File

@@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) {
group.raw_group()->recv(out, src, stream);
}
void sum_scatter(
Group group,
const array& input,
array& output,
Stream stream) {
group.raw_group()->sum_scatter(input, output, stream);
}
class EmptyGroup : public GroupImpl {
public:
Stream communication_stream(StreamOrDevice s) override {
@@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void sum_scatter(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
};
} // namespace detail

View File

@@ -28,6 +28,8 @@ class GroupImpl {
virtual void recv(array& out, int src, Stream stream) = 0;
virtual void all_max(const array& input, array& output, Stream stream) = 0;
virtual void all_min(const array& input, array& output, Stream stream) = 0;
virtual void
sum_scatter(const array& input, array& output, Stream stream) = 0;
};
/* Define the MLX stream that the communication should happen in. */
@@ -51,4 +53,7 @@ void all_max(Group group, const array& input, array& output, Stream stream);
/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);
/** Reduce scatter with sum operation */
void sum_scatter(Group group, const array& input, array& output, Stream stream);
} // namespace mlx::core::distributed::detail

View File

@@ -465,6 +465,10 @@ class MPIGroup : public GroupImpl {
});
}
void sum_scatter(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
}
private:
MPI_Comm comm_;
bool global_;

View File

@@ -296,8 +296,17 @@ class NCCLGroup : public GroupImpl {
}
void all_gather(const array& input, array& output, Stream stream) override {
-throw std::runtime_error(
-"[nccl] All gather not supported in NCCL backend.");
+detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
+using T = typename decltype(type_tag)::type;
+auto& encoder = cu::get_command_encoder(stream);
+CHECK_NCCL(ncclAllGather(
+input.data<T>(),
+output.data<T>(),
+input.size(),
+dt,
+comm_,
+encoder.stream()));
+});
}
void send(const array& input, int dst, Stream stream) override {
@@ -309,11 +318,24 @@ class NCCLGroup : public GroupImpl {
}
void all_max(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All max not supported in NCCL backend.");
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMax);
});
}
void all_min(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[nccl] All min not supported in NCCL backend.");
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
all_reduce_impl<T>(input, output, stream, dt, ncclMin);
});
}
void sum_scatter(const array& input, array& output, Stream stream) override {
detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
using T = typename decltype(type_tag)::type;
reduce_scatter_impl<T>(input, output, stream, dt, ncclSum);
});
}
template <typename T>
@@ -335,6 +357,25 @@ class NCCLGroup : public GroupImpl {
encoder.stream()));
}
template <typename T>
void reduce_scatter_impl(
const array& input,
array& output,
Stream stream,
ncclDataType_t dt,
ncclRedOp_t op) {
auto& encoder = cu::get_command_encoder(stream);
CHECK_NCCL(ncclReduceScatter(
input.data<T>(),
output.data<T>(),
output.size(),
dt,
op,
comm_,
encoder.stream()));
}
int rank_, size_;
std::string initMethod_;
ncclUniqueId uniqueId_;
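
One detail worth noting in reduce_scatter_impl above: ncclReduceScatter takes the per-rank receive count, so output.size() is the right argument (ncclAllGather, by contrast, takes the per-rank send count, which is why all_gather passes input.size()). A small worked example, assuming a group of 4 ranks and an (8, 16) input on each rank:

// Sketch only: element-count bookkeeping for sum_scatter on 4 ranks.
constexpr int group_size = 4;
constexpr size_t input_elems = 8 * 16;                  // input.size() on every rank
constexpr size_t recv_count = input_elems / group_size; // output.size() == 2 * 16
// ncclReduceScatter(send, recv, recv_count, dt, ncclSum, comm, stream) then
// sums the i-th recv_count-sized chunk across ranks and leaves it on rank i.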

View File

@@ -157,4 +157,30 @@ array recv_like(
return recv(x.shape(), x.dtype(), src, group_, s);
}
array sum_scatter(
const array& x,
std::optional<Group> group_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
auto group = to_group(group_);
if (group.size() == 1) {
return x;
}
if (x.shape()[0] % group.size() != 0) {
std::ostringstream msg;
msg << "[sum_scatter] Invalid shape=" << x.shape()
<< " for a group of size " << group.size()
<< ". The first dimension (axis 0) must be divisible by the group size.";
throw std::invalid_argument(msg.str());
}
auto result_shape = x.shape();
result_shape[0] /= group.size();
auto stream = detail::communication_stream(group, s);
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
{x});
}
} // namespace mlx::core::distributed
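
For completeness, a hedged usage sketch of the new op from C++. It assumes the existing distributed::init() entry point and a launch with 4 ranks (e.g. via mpirun); the output shape follows the axis-0 divisibility rule enforced above:

#include "mlx/mlx.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"

using namespace mlx::core;

int main() {
  // Assumes this binary is launched on 4 ranks.
  auto group = distributed::init();

  // Every rank contributes an (8, 16) array of ones.
  auto x = ones({8, 16}, float32);

  // Sum across ranks, then scatter along axis 0: each rank ends up with a
  // (2, 16) slice whose entries all equal the group size.
  auto y = distributed::sum_scatter(x, group);
  eval(y);
  return 0;
}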

View File

@@ -48,4 +48,9 @@ array all_min(
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
array sum_scatter(
const array& x,
std::optional<Group> group = std::nullopt,
StreamOrDevice s = {});
} // namespace mlx::core::distributed

View File

@@ -127,4 +127,30 @@ class Recv : public DistPrimitive {
int src_;
};
class ReduceScatter : public DistPrimitive {
public:
enum ReduceType { Sum, Min, Max };
ReduceScatter(Stream stream, Group group, ReduceType reduce_type)
: DistPrimitive(stream, group), reduce_type_(reduce_type) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
const char* name() const override {
switch (reduce_type_) {
case Sum:
return "Sum ReduceScatter";
case Min:
return "Min ReduceScatter";
case Max:
return "Max ReduceScatter";
}
return "<unknwon ReduceScatter>";
}
private:
ReduceType reduce_type_;
};
} // namespace mlx::core::distributed
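
The ReduceType enum above already lists Min and Max, but only Sum is dispatched in ReduceScatter::eval_gpu (everything else throws). If the remaining reductions were wired up later, the mapping onto NCCL would presumably be a switch like the following hypothetical sketch, which is not part of this commit:

// Hypothetical mapping, mirroring how Sum is dispatched today.
ncclRedOp_t to_nccl_op(ReduceScatter::ReduceType t) {
  switch (t) {
    case ReduceScatter::Sum:
      return ncclSum;
    case ReduceScatter::Min:
      return ncclMin;
    case ReduceScatter::Max:
      return ncclMax;
  }
  return ncclSum; // unreachable, keeps compilers quiet
}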

View File

@@ -731,6 +731,10 @@ class RingGroup : public GroupImpl {
});
}
void sum_scatter(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[ring] sum_scatter not supported.");
}
private:
template <typename T, typename ReduceOp>
void all_reduce(