redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization

Awni Hannun authored on 2025-03-06 19:23:38 -08:00; committed by GitHub
parent 5245f12a46 · commit c4230747a1
103 changed files with 5013 additions and 3873 deletions


@@ -6,31 +6,25 @@
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/distributed/ring/ring.h"
#include "mlx/scheduler.h"
namespace mlx::core::distributed {
namespace detail {
Stream communication_stream() {
static Stream comm_stream = new_stream(Device::cpu);
return comm_stream;
void all_sum(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_sum(input, output, stream);
}
void all_sum(Group group, const array& input, array& output) {
group.raw_group()->all_sum(input, output);
void all_gather(Group group, const array& input, array& output, Stream stream) {
group.raw_group()->all_gather(input, output, stream);
}
void all_gather(Group group, const array& input, array& output) {
group.raw_group()->all_gather(input, output);
void send(Group group, const array& input, int dst, Stream stream) {
group.raw_group()->send(input, dst, stream);
}
void send(Group group, const array& input, int dst) {
group.raw_group()->send(input, dst);
}
void recv(Group group, array& out, int src) {
group.raw_group()->recv(out, src);
void recv(Group group, array& out, int src, Stream stream) {
group.raw_group()->recv(out, src, stream);
}
class EmptyGroup : public GroupImpl {
@@ -47,19 +41,19 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error("Cannot split the distributed group further.");
}
void all_sum(const array& input, array& output) override {
void all_sum(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void all_gather(const array& input, array& output) override {
void all_gather(const array&, array&, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void send(const array& input, int dst) override {
void send(const array&, int, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void recv(array& out, int src) override {
void recv(array&, int, Stream) override {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
@@ -122,10 +116,6 @@ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
backends.insert({"any", group});
}
backends.insert({std::move(bk_), group});
// Ensure the communication stream is alive before
// the graph is evaluated
detail::communication_stream();
return Group(group);
}
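
With communication_stream() removed, init() no longer has to keep a dedicated stream alive; every collective now runs on a stream the caller passes in explicitly. A minimal sketch of the new calling convention, assuming a group, input, and output already in scope and creating a CPU stream the same way the deleted helper did:

// Previously the stream was implicit and global:
//   detail::all_sum(group, input, output);   // always used communication_stream()
// Now the caller chooses the stream the collective is encoded on:
Stream s = new_stream(Device::cpu);           // any CPU stream works
detail::all_sum(group, input, output, s);
detail::all_gather(group, input, output, s);
detail::send(group, input, /* dst */ 1, s);
detail::recv(group, output, /* src */ 1, s);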


@@ -17,25 +17,22 @@ class GroupImpl {
virtual int size() = 0;
virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
virtual void all_sum(const array& input, array& output) = 0;
virtual void all_gather(const array& input, array& output) = 0;
virtual void send(const array& input, int dst) = 0;
virtual void recv(array& out, int src) = 0;
virtual void all_sum(const array& input, array& output, Stream stream) = 0;
virtual void all_gather(const array& input, array& output, Stream stream) = 0;
virtual void send(const array& input, int dst, Stream stream) = 0;
virtual void recv(array& out, int src, Stream stream) = 0;
};
/* Return the communication stream. */
Stream communication_stream();
/* Perform an all reduce sum operation */
void all_sum(Group group, const array& input, array& output);
void all_sum(Group group, const array& input, array& output, Stream stream);
/* Perform an all gather operation */
void all_gather(Group group, const array& input, array& output);
void all_gather(Group group, const array& input, array& output, Stream stream);
/** Send an array to the dst rank */
void send(Group group, const array& input, int dst);
void send(Group group, const array& input, int dst, Stream stream);
/** Recv an array from the src rank */
void recv(Group group, array& out, int src);
void recv(Group group, array& out, int src, Stream stream);
} // namespace mlx::core::distributed::detail
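
Every backend implements this same Stream-aware surface. A hypothetical skeleton of a third-party backend, showing only the members this commit touches (the class name is an assumption, and rank() is assumed from context since it is not visible in this hunk):

#include "mlx/distributed/distributed_impl.h"

namespace my_backend {

using namespace mlx::core;

class MyGroup : public distributed::detail::GroupImpl {
 public:
  int rank() override;  // assumed to exist alongside size(); not shown above
  int size() override;
  std::shared_ptr<distributed::detail::GroupImpl> split(int color, int key = -1) override;

  // All collectives now receive the stream they should be encoded on:
  void all_sum(const array& input, array& output, Stream stream) override;
  void all_gather(const array& input, array& output, Stream stream) override;
  void send(const array& input, int dst, Stream stream) override;
  void recv(array& out, int src, Stream stream) override;
};

} // namespace my_backend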


@@ -3,11 +3,10 @@
#include <dlfcn.h>
#include <mpi.h>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/distributed/mpi/mpi.h"
#include "mlx/scheduler.h"
#define LOAD_SYMBOL(symbol, variable) \
{ \
@@ -25,16 +24,6 @@ using GroupImpl = mlx::core::distributed::detail::GroupImpl;
namespace {
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
template <typename T>
void simple_sum(
void* input,
@@ -281,9 +270,12 @@ class MPIGroup : public GroupImpl {
return std::make_shared<MPIGroup>(new_comm, false);
}
void all_sum(const array& input_, array& output) override {
array input = ensure_row_contiguous(input_);
mpi().all_reduce(
void all_sum(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_reduce,
(input.data<void>() == output.data<void>()) ? MPI_IN_PLACE
: input.data<void>(),
output.data<void>(),
@@ -293,9 +285,12 @@ class MPIGroup : public GroupImpl {
comm_);
}
void all_gather(const array& input_, array& output) override {
array input = ensure_row_contiguous(input_);
mpi().all_gather(
void all_gather(const array& input, array& output, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.set_output_array(output);
encoder.dispatch(
mpi().all_gather,
input.data<void>(),
input.size(),
mpi().datatype(input),
@@ -305,22 +300,30 @@ class MPIGroup : public GroupImpl {
comm_);
}
void send(const array& input_, int dst) override {
array input = ensure_row_contiguous(input_);
mpi().send(
input.data<void>(), input.size(), mpi().datatype(input), dst, 0, comm_);
void send(const array& input, int dst, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.dispatch(
mpi().send,
input.data<void>(),
input.size(),
mpi().datatype(input),
dst,
0,
comm_);
}
void recv(array& out, int src) override {
MPI_Status status;
mpi().recv(
out.data<void>(),
out.size(),
mpi().datatype(out),
src,
MPI_ANY_TAG,
comm_,
&status);
void recv(array& out, int src, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch([out_ptr = out.data<void>(),
out_size = out.size(),
out_type = mpi().datatype(out),
src,
comm = comm_]() {
MPI_Status status;
mpi().recv(out_ptr, out_size, out_type, src, MPI_ANY_TAG, comm, &status);
});
}
private:
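
The same pattern repeats for every MPI call above: instead of running the collective synchronously on the calling thread, the op is recorded on the CPU command encoder of the given stream and executed later by that stream's worker, after previously encoded work. A condensed sketch of the pattern; the function name and the memcpy body are illustrative assumptions, and only the encoder calls mirror the diff:

#include <cstring>

#include "mlx/backend/cpu/encoder.h"

namespace mlx::core {

void enqueue_collective(const array& in, array& out, Stream stream) {
  // One command encoder per CPU stream.
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(in);    // declare `in` as read by this task
  encoder.set_output_array(out);  // declare `out` as written by this task
  // Capture raw pointers and sizes by value; the lambda runs asynchronously
  // once the stream's worker reaches this point in its queue.
  encoder.dispatch([src = in.data<char>(),
                    dst = out.data<char>(),
                    nbytes = in.nbytes()]() { std::memcpy(dst, src, nbytes); });
}

} // namespace mlx::core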


@@ -28,11 +28,11 @@ array all_sum(
if (group.size() == 1) {
return x;
}
return array(
x.shape(),
x.dtype(),
std::make_shared<AllReduce>(to_stream(s), group, AllReduce::Sum),
std::make_shared<AllReduce>(
to_stream(s, Device::cpu), group, AllReduce::Sum),
{x});
}
@@ -55,7 +55,7 @@ array all_gather(
return array(
std::move(result_shape),
x.dtype(),
std::make_shared<AllGather>(to_stream(s), group),
std::make_shared<AllGather>(to_stream(s, Device::cpu), group),
{x});
}
@@ -80,7 +80,7 @@ array send(
return array(
x.shape(),
x.dtype(),
std::make_shared<Send>(to_stream(s), group, dst),
std::make_shared<Send>(to_stream(s, Device::cpu), group, dst),
{x});
}
@@ -105,7 +105,7 @@ array recv(
return array(
std::move(shape),
std::move(dtype),
std::make_shared<Recv>(to_stream(s), group, src),
std::make_shared<Recv>(to_stream(s, Device::cpu), group, src),
std::vector<array>{});
}
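
Every distributed op now builds its primitive with to_stream(s, Device::cpu), so the collective is pinned to a CPU stream even when the caller's default stream lives on the GPU. A small usage sketch under that reading (the setup code is assumed, not part of this diff):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  auto group = distributed::init();   // picks up the MPI or ring backend
  array x = ones({1024});
  // No stream is passed: the AllReduce primitive is still created on a CPU
  // stream via to_stream(s, Device::cpu), even when the default device is
  // the GPU.
  array y = distributed::all_sum(x, group);
  eval(y);
  return 0;
}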


@@ -3,34 +3,12 @@
#include <cassert>
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/ops.h"
namespace mlx::core::distributed {
void AllReduce::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
if (inputs[0].is_donatable()) {
outputs[0].copy_shared_buffer(inputs[0]);
} else {
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
}
switch (reduce_type_) {
case Sum:
distributed::detail::all_sum(group(), inputs[0], outputs[0]);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
}
std::pair<std::vector<array>, std::vector<int>> AllReduce::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -62,17 +40,6 @@ std::vector<array> AllReduce::vjp(
return cotangents;
}
void AllGather::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::all_gather(group(), inputs[0], outputs[0]);
}
std::pair<std::vector<array>, std::vector<int>> AllGather::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
@@ -99,30 +66,10 @@ std::vector<array> AllGather::vjp(
return {slice(cotangents[0], starts, stops)};
}
void Send::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
distributed::detail::send(group(), inputs[0], dst_);
move_or_copy(inputs[0], outputs[0]);
}
std::pair<std::vector<array>, std::vector<int>> Send::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
return {{send(inputs[0], dst_, group(), stream())}, axes};
}
void Recv::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
distributed::detail::recv(group(), outputs[0], src_);
}
} // namespace mlx::core::distributed
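
The deleted eval_cpu bodies are not dropped; they are presumably relocated alongside the CPU backend (the mlx/backend/cpu/encoder.h includes added elsewhere in this commit point that way), where they forward the primitive's own stream into the detail API. A sketch of what the relocated AllReduce evaluation could look like under that assumption, reusing the body removed above plus the new stream argument:

// In the primitive's CPU implementation file, with the same includes as the
// file above (<cassert>, mlx/allocator.h, ...).
void AllReduce::eval_cpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  assert(inputs.size() == 1);
  assert(outputs.size() == 1);
  if (inputs[0].is_donatable()) {
    outputs[0].copy_shared_buffer(inputs[0]);   // reuse the donated buffer
  } else {
    outputs[0].set_data(allocator::malloc_or_wait(outputs[0].nbytes()));
  }
  switch (reduce_type_) {
    case Sum:
      // The essential change: the primitive's stream is forwarded so the
      // backend can encode the communication asynchronously.
      distributed::detail::all_sum(group(), inputs[0], outputs[0], stream());
      break;
    default:
      throw std::runtime_error("Only all reduce sum is supported for now");
  }
}

The real implementation likely also handles row-contiguity and temporaries (see the "fixes + handle temporaries" and "fix donation / temporaries" bullets above), which this sketch omits.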


@@ -18,7 +18,7 @@
#include <json.hpp>
#include "mlx/backend/cpu/copy.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/distributed_impl.h"
#include "mlx/threadpool.h"
@@ -140,8 +140,8 @@ class SocketThread {
}
template <typename T>
std::future<void> send(T* buffer, size_t size) {
return send_impl(reinterpret_cast<char*>(buffer), size * sizeof(T));
std::future<void> send(const T* buffer, size_t size) {
return send_impl(reinterpret_cast<const char*>(buffer), size * sizeof(T));
}
template <typename T>
@@ -160,7 +160,7 @@ class SocketThread {
std::promise<void> promise;
};
std::future<void> send_impl(char* buffer, size_t size) {
std::future<void> send_impl(const char* buffer, size_t size) {
std::promise<void> send_completed_promise;
auto send_completed_future = send_completed_promise.get_future();
if (size == 0) {
@@ -170,8 +170,8 @@ class SocketThread {
{
std::unique_lock lock(queue_mutex_);
sends_.emplace_back(
SocketTask(buffer, size, std::move(send_completed_promise)));
sends_.emplace_back(SocketTask(
const_cast<char*>(buffer), size, std::move(send_completed_promise)));
}
condition_.notify_one();
return send_completed_future;
@@ -503,16 +503,6 @@ std::vector<int> make_connections(
return sockets;
}
array ensure_row_contiguous(const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy(arr, arr_copy, CopyType::General);
return arr_copy;
}
}
template <typename T>
void sum_inplace(const T* input, T* output, size_t N) {
while (N-- > 0) {
@@ -613,117 +603,131 @@ class RingGroup : public GroupImpl {
return size_;
}
void all_sum(const array& input_, array& output) override {
SWITCH_TYPE(output, all_sum<T>(input_, output));
void all_sum(const array& input_, array& output, Stream stream) override {
SWITCH_TYPE(output, all_sum<T>(input_, output, stream));
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ring] Group split not supported.");
}
void all_gather(const array& input, array& output) override {
void all_gather(const array& input, array& output, Stream stream) override {
throw std::runtime_error("[ring] All gather not supported.");
}
void send(const array& input_, int dst) override {
// Make sure that the input is row contiguous
array input = ensure_row_contiguous(input_);
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (dst == right) {
send(sockets_right_, input.data<char>(), input.nbytes());
} else if (dst == left) {
send(sockets_left_, input.data<char>(), input.nbytes());
} else {
std::ostringstream msg;
msg << "[ring] Send only supported to direct neighbors "
<< "but tried to send to " << dst << " from " << rank_ << std::endl;
throw std::runtime_error(msg.str());
}
void send(const array& input, int dst, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(input);
encoder.dispatch(
[input_ptr = input.data<char>(), nbytes = input.nbytes(), dst, this]() {
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (dst == right) {
send(sockets_right_, input_ptr, nbytes);
} else if (dst == left) {
send(sockets_left_, input_ptr, nbytes);
} else {
std::ostringstream msg;
msg << "[ring] Send only supported to direct neighbors "
<< "but tried to send to " << dst << " from " << rank_
<< std::endl;
throw std::runtime_error(msg.str());
}
});
}
void recv(array& out, int src) override {
// NOTE: We 'll check the sockets with the opposite order of send so that
// they work even with 2 nodes where left and right is the same
// neighbor.
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (src == left) {
recv(sockets_left_, out.data<char>(), out.nbytes());
} else if (src == right) {
recv(sockets_right_, out.data<char>(), out.nbytes());
} else {
std::ostringstream msg;
msg << "[ring] Recv only supported from direct neighbors "
<< "but tried to recv from " << src << " to " << rank_ << std::endl;
throw std::runtime_error(msg.str());
}
void recv(array& out, int src, Stream stream) override {
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(out);
encoder.dispatch(
[out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
// NOTE: We 'll check the sockets with the opposite order of send so
// that
// they work even with 2 nodes where left and right is the same
// neighbor.
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (src == left) {
recv(sockets_left_, out_ptr, nbytes);
} else if (src == right) {
recv(sockets_right_, out_ptr, nbytes);
} else {
std::ostringstream msg;
msg << "[ring] Recv only supported from direct neighbors "
<< "but tried to recv from " << src << " to " << rank_
<< std::endl;
throw std::runtime_error(msg.str());
}
});
}
private:
template <typename T>
void all_sum(const array& input_, array& output) {
// Make sure that the input is row contiguous
array input = ensure_row_contiguous(input_);
void all_sum(const array& input, array& output, Stream stream) {
auto in_ptr = input.data<char>();
auto out_ptr = output.data<char>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_output_array(output);
encoder.dispatch([in_ptr, out_ptr, size = input.size(), this]() {
// If the input data cannot be split into size_ segments then copy it and
// all reduce a local buffer prefilled with 0s.
size_t nbytes = size * sizeof(T);
if (size < size_) {
// TODO: Maybe allocate dynamically so we don't have the constraint
// below?
if (sizeof(T) * size_ > 1024) {
std::ostringstream msg;
msg << "Can't perform the ring all reduce of " << size
<< " elements with a ring of size " << size_;
throw std::runtime_error(msg.str());
}
// If the input data cannot be split into size_ segments then copy it and
// all reduce a local buffer prefilled with 0s.
if (input.size() < size_) {
// TODO: Maybe allocate dynamically so we don't have the constraint
// below?
if (input.itemsize() * size_ > 1024) {
std::ostringstream msg;
msg << "Can't perform the ring all reduce of " << output.size()
<< " elements with a ring of size " << size_;
throw std::runtime_error(msg.str());
char buffer[1024];
std::memset(buffer, 0, size_ * sizeof(T));
std::memcpy(buffer, in_ptr, nbytes);
all_sum_impl<T>(
reinterpret_cast<T*>(buffers_.data()),
reinterpret_cast<T*>(buffer),
size_,
sockets_right_[0],
sockets_left_[0],
-1);
std::memcpy(out_ptr, buffer, nbytes);
return;
}
char buffer[1024];
std::memset(buffer, 0, size_ * input.itemsize());
std::memcpy(buffer, input.data<char>(), input.nbytes());
all_sum_impl<T>(
reinterpret_cast<T*>(buffers_.data()),
reinterpret_cast<T*>(buffer),
size_,
sockets_right_[0],
sockets_left_[0],
-1);
std::memcpy(output.data<char>(), buffer, output.nbytes());
return;
}
// If not inplace all reduce then copy the input to the output first
if (in_ptr != out_ptr) {
std::memcpy(out_ptr, in_ptr, nbytes);
}
// If not inplace all reduce then copy the input to the output first
if (input.data<void>() != output.data<void>()) {
std::memcpy(output.data<char>(), input.data<char>(), input.nbytes());
}
// Split the all reduces so that each member has at least 1 buffer to
// send/recv per segment.
constexpr size_t min_send_size = 262144;
size_t n_reduces = std::max(
std::min(
sockets_right_.size() + sockets_left_.size(),
nbytes / (size_ * min_send_size)),
1UL);
size_t step = ceildiv(size, n_reduces);
std::vector<std::future<void>> all_sums;
// Split the all reduces so that each member has at least 1 buffer to
// send/recv per segment.
constexpr size_t min_send_size = 262144;
size_t n_reduces = std::max(
std::min(
sockets_right_.size() + sockets_left_.size(),
output.nbytes() / (size_ * min_send_size)),
1UL);
size_t step = ceildiv(output.size(), n_reduces);
std::vector<std::future<void>> all_sums;
for (int i = 0; i < n_reduces; i++) {
all_sums.emplace_back(pool_.enqueue(std::bind(
&RingGroup::all_sum_impl<T>,
this,
reinterpret_cast<T*>(
buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
output.data<T>() + i * step,
std::min(output.size(), (i + 1) * step) - i * step,
sockets_right_[i / 2],
sockets_left_[i / 2],
(i % 2) ? -1 : 1)));
}
for (auto& f : all_sums) {
f.wait();
}
for (int i = 0; i < n_reduces; i++) {
all_sums.emplace_back(pool_.enqueue(std::bind(
&RingGroup::all_sum_impl<T>,
this,
reinterpret_cast<T*>(
buffers_.data() + i * ALL_SUM_SIZE * ALL_SUM_BUFFERS),
reinterpret_cast<T*>(out_ptr) + i * step,
std::min(size, (i + 1) * step) - i * step,
sockets_right_[i / 2],
sockets_left_[i / 2],
(i % 2) ? -1 : 1)));
}
for (auto& f : all_sums) {
f.wait();
}
});
}
template <typename T>
@@ -823,7 +827,8 @@ class RingGroup : public GroupImpl {
recvs[b].wait();
}
void send(const std::vector<int>& sockets, char* data, size_t data_size) {
void
send(const std::vector<int>& sockets, const char* data, size_t data_size) {
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
std::vector<std::future<void>> sends;
for (int i = 0; i < sockets.size(); i++) {
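
For reference, the chunking that all_sum performs above: the reduction is split into n_reduces pieces of step = ceildiv(size, n_reduces) elements each; pairs of chunks share the sockets at index i / 2 and run in opposite ring directions. A standalone sketch of just that arithmetic (the constants are illustrative, not taken from the code):

#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr size_t ceildiv(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  size_t size = 1'000'000;  // elements to all-reduce
  size_t n_reduces = 4;     // parallel ring reductions
  size_t step = ceildiv(size, n_reduces);  // 250'000 elements per chunk
  for (size_t i = 0; i < n_reduces; ++i) {
    size_t offset = i * step;
    size_t count = std::min(size, (i + 1) * step) - offset;
    int direction = (i % 2) ? -1 : 1;  // chunks sharing sockets_[i / 2] go opposite ways
    std::printf("chunk %zu: offset %zu, %zu elements, direction %+d\n",
                i, offset, count, direction);
  }
  return 0;
}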