Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
NCCL reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl
* fix unused import, delete unused file
* small fix
* deleted useless condition
* fixed comments
* fix bug in eval_gpu, renamed to sum_scatter, fix docs
* final fix docs
* remove and
* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output
* fixes set output
* typo
* fix typo
* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Commit 27778156dc (parent 761f901a41), committed by GitHub.
@@ -95,4 +95,9 @@ void Recv::eval_cpu(
  distributed::detail::recv(group(), outputs[0], src_, stream());
}

void ReduceScatter::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  throw std::runtime_error("[ReduceScatter] Not implemented yet.");
}
} // namespace mlx::core::distributed

@@ -53,4 +53,69 @@ void AllReduce::eval_gpu(
        "Only all reduce sum, max, and min are supported.");
  }
}

void AllGather::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().row_contiguous) {
      return x;
    } else {
      array x_copy = contiguous_copy_gpu(x, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto input = ensure_contiguous(inputs[0]);
  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));

  encoder.set_input_array(input);
  encoder.set_output_array(outputs[0]);

  auto capture = encoder.capture_context();
  distributed::detail::all_gather(group(), input, outputs[0], s);
}

void ReduceScatter::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  auto ensure_contiguous = [&s, &encoder](const array& x) {
    if (x.flags().row_contiguous) {
      return x;
    } else {
      array x_copy = contiguous_copy_gpu(x, s);
      encoder.add_temporary(x_copy);
      return x_copy;
    }
  };

  auto input = ensure_contiguous(inputs[0]);
  outputs[0].set_data(allocator::malloc(outputs[0].nbytes()));

  encoder.set_input_array(input);
  encoder.set_output_array(outputs[0]);

  auto capture = encoder.capture_context();

  switch (reduce_type_) {
    case Sum:
      distributed::detail::sum_scatter(group(), input, outputs[0], s);
      break;
    default:
throw std::runtime_error("Only sum scatter is supported. ");
|
||||
}
|
||||
}
|
||||
} // namespace mlx::core::distributed
|
||||
|
||||
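Editorial aside, not part of the diff: both eval_gpu bodies above size the output purely from outputs[0].nbytes(), which the op layer (see the sum_scatter addition to ops.cpp further down) sets to group-size times the input along axis 0 for AllGather and to one group-size-th of it for ReduceScatter. A minimal sketch of that byte relationship, assuming a hypothetical group_size and an input laid out as [N, ...]:

// Editor's sketch (assumed names, not from the diff): byte sizes implied by
// the allocations above for an input of shape [N, ...] across group_size ranks.
#include <cstddef>

std::size_t all_gather_nbytes(std::size_t input_nbytes, std::size_t group_size) {
  return input_nbytes * group_size; // output shape is [N * group_size, ...]
}

std::size_t reduce_scatter_nbytes(std::size_t input_nbytes, std::size_t group_size) {
  return input_nbytes / group_size; // output shape is [N / group_size, ...]
}
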
@@ -40,7 +40,6 @@ NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)

namespace distributed {
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
} // namespace distributed

@@ -30,4 +30,9 @@ void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
}

void ReduceScatter::eval_gpu(const std::vector<array>&, std::vector<array>&) {
  throw std::runtime_error(
      "[ReduceScatter::eval_gpu] has no GPU implementation.");
}

} // namespace mlx::core::distributed

@@ -138,6 +138,7 @@ NO_CPU_MULTI(AllReduce)
NO_CPU_MULTI(AllGather)
NO_CPU_MULTI(Send)
NO_CPU_MULTI(Recv)
NO_CPU_MULTI(ReduceScatter)
} // namespace distributed

} // namespace mlx::core

@@ -164,6 +164,7 @@ NO_GPU_MULTI(AllReduce)
NO_GPU_MULTI(AllGather)
NO_GPU_MULTI(Send)
NO_GPU_MULTI(Recv)
NO_GPU_MULTI(ReduceScatter)
} // namespace distributed

} // namespace mlx::core

@@ -41,6 +41,14 @@ void recv(Group group, array& out, int src, Stream stream) {
  group.raw_group()->recv(out, src, stream);
}

void sum_scatter(
    Group group,
    const array& input,
    array& output,
    Stream stream) {
  group.raw_group()->sum_scatter(input, output, stream);
}

class EmptyGroup : public GroupImpl {
 public:
  Stream communication_stream(StreamOrDevice s) override {
@@ -85,6 +93,10 @@ class EmptyGroup : public GroupImpl {
    throw std::runtime_error(
        "Communication not implemented in an empty distributed group.");
  }
  void sum_scatter(const array&, array&, Stream) override {
    throw std::runtime_error(
        "Communication not implemented in an empty distributed group.");
  }
};

} // namespace detail

@@ -28,6 +28,8 @@ class GroupImpl {
  virtual void recv(array& out, int src, Stream stream) = 0;
  virtual void all_max(const array& input, array& output, Stream stream) = 0;
  virtual void all_min(const array& input, array& output, Stream stream) = 0;
  virtual void
  sum_scatter(const array& input, array& output, Stream stream) = 0;
};

/* Define the MLX stream that the communication should happen in. */
@@ -51,4 +53,7 @@ void all_max(Group group, const array& input, array& output, Stream stream);
/** Min reduction */
void all_min(Group group, const array& input, array& output, Stream stream);

/** Reduce scatter with sum operation */
void sum_scatter(Group group, const array& input, array& output, Stream stream);

} // namespace mlx::core::distributed::detail

@@ -465,6 +465,10 @@ class MPIGroup : public GroupImpl {
    });
  }

  void sum_scatter(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[mpi] sum_scatter not yet implemented.");
  }

 private:
  MPI_Comm comm_;
  bool global_;

@@ -296,8 +296,17 @@ class NCCLGroup : public GroupImpl {
  }

  void all_gather(const array& input, array& output, Stream stream) override {
    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
      using T = typename decltype(type_tag)::type;
      auto& encoder = cu::get_command_encoder(stream);
      CHECK_NCCL(ncclAllGather(
          input.data<T>(),
          output.data<T>(),
          input.size(),
          dt,
          comm_,
          encoder.stream()));
    });
  }

  void send(const array& input, int dst, Stream stream) override {
@@ -309,11 +318,24 @@ class NCCLGroup : public GroupImpl {
  }

  void all_max(const array& input, array& output, Stream stream) override {
    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
      using T = typename decltype(type_tag)::type;
      all_reduce_impl<T>(input, output, stream, dt, ncclMax);
    });
  }

  void all_min(const array& input, array& output, Stream stream) override {
    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
      using T = typename decltype(type_tag)::type;
      all_reduce_impl<T>(input, output, stream, dt, ncclMin);
    });
  }

  void sum_scatter(const array& input, array& output, Stream stream) override {
    detail::dispatch_dtype(input, [&](auto type_tag, ncclDataType_t dt) {
      using T = typename decltype(type_tag)::type;
      reduce_scatter_impl<T>(input, output, stream, dt, ncclSum);
    });
  }

  template <typename T>
@@ -335,6 +357,25 @@ class NCCLGroup : public GroupImpl {
        encoder.stream()));
  }

  template <typename T>
  void reduce_scatter_impl(
      const array& input,
      array& output,
      Stream stream,
      ncclDataType_t dt,
      ncclRedOp_t op) {
    auto& encoder = cu::get_command_encoder(stream);

    CHECK_NCCL(ncclReduceScatter(
        input.data<T>(),
        output.data<T>(),
        output.size(),
        dt,
        op,
        comm_,
        encoder.stream()));
  }

  int rank_, size_;
  std::string initMethod_;
  ncclUniqueId uniqueId_;

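Editorial note, not part of the diff: NCCL collectives take per-rank element counts, which is why ncclAllGather above is passed input.size() (the elements each rank contributes) while ncclReduceScatter is passed output.size() (the elements each rank receives). A small sketch of the implied invariants, with group_size standing in for the communicator size (assumed name):

// Editor's sketch, not from the diff: size invariants the NCCL calls above
// rely on, expressed as plain assertions over element counts.
#include <cassert>
#include <cstddef>

void check_collective_counts(
    std::size_t input_elems,  // input.size()
    std::size_t output_elems, // output.size()
    std::size_t group_size,   // number of ranks in the group
    bool is_reduce_scatter) {
  if (is_reduce_scatter) {
    // ncclReduceScatter: every rank receives output_elems reduced elements.
    assert(input_elems == output_elems * group_size);
  } else {
    // ncclAllGather: every rank sends input_elems elements.
    assert(output_elems == input_elems * group_size);
  }
}
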
@@ -157,4 +157,30 @@ array recv_like(
  return recv(x.shape(), x.dtype(), src, group_, s);
}

array sum_scatter(
    const array& x,
    std::optional<Group> group_ /* = std::nullopt */,
    StreamOrDevice s /* = {} */) {
  auto group = to_group(group_);
  if (group.size() == 1) {
    return x;
  }
  if (x.shape()[0] % group.size() != 0) {
    std::ostringstream msg;
    msg << "[sum_scatter] Invalid shape=" << x.shape()
        << " for a group of size " << group.size()
        << ". The first dimension (axis 0) must be divisible by the group size.";
    throw std::invalid_argument(msg.str());
  }

  auto result_shape = x.shape();
  result_shape[0] /= group.size();
  auto stream = detail::communication_stream(group, s);

  return array(
      std::move(result_shape),
      x.dtype(),
      std::make_shared<ReduceScatter>(stream, group, ReduceScatter::Sum),
      {x});
}
} // namespace mlx::core::distributed

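Editorial usage sketch, not part of the diff: the shape contract enforced above, seen from a caller. It assumes a group of 4 processes and uses only distributed::init, ones, eval, and the sum_scatter op added in this change; the specific shapes are illustrative.

// Editor's sketch, not from the diff: calling the new sum_scatter op.
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto group = distributed::init(); // e.g. launched with 4 processes
  // Each rank contributes an [8, 16] array; axis 0 must divide by group.size().
  auto x = ones({8, 16});
  // Every rank gets a [2, 16] shard of the element-wise sum across ranks:
  // rank r holds rows [r * 2, (r + 1) * 2) of the summed [8, 16] array.
  auto y = distributed::sum_scatter(x, group);
  eval({y});
  return 0;
}
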
@@ -48,4 +48,9 @@ array all_min(
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

array sum_scatter(
    const array& x,
    std::optional<Group> group = std::nullopt,
    StreamOrDevice s = {});

} // namespace mlx::core::distributed

@@ -127,4 +127,30 @@ class Recv : public DistPrimitive {
  int src_;
};

class ReduceScatter : public DistPrimitive {
 public:
  enum ReduceType { Sum, Min, Max };
  ReduceScatter(Stream stream, Group group, ReduceType reduce_type)
      : DistPrimitive(stream, group), reduce_type_(reduce_type) {}

  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;
  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
      override;

  const char* name() const override {
    switch (reduce_type_) {
      case Sum:
        return "Sum ReduceScatter";
      case Min:
        return "Min ReduceScatter";
      case Max:
        return "Max ReduceScatter";
    }
return "<unknwon ReduceScatter>";
  }

 private:
  ReduceType reduce_type_;
};
} // namespace mlx::core::distributed

@@ -731,6 +731,10 @@ class RingGroup : public GroupImpl {
    });
  }

  void sum_scatter(const array& input, array& output, Stream stream) override {
    throw std::runtime_error("[ring] sum_scatter not supported.");
  }

 private:
  template <typename T, typename ReduceOp>
  void all_reduce(