mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Add a ring all gather (#1985)
This commit is contained in:
parent
25814a9458
commit
69e4dd506b
@ -603,8 +603,8 @@ class RingGroup : public GroupImpl {
|
|||||||
return size_;
|
return size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void all_sum(const array& input_, array& output, Stream stream) override {
|
void all_sum(const array& input, array& output, Stream stream) override {
|
||||||
SWITCH_TYPE(output, all_sum<T>(input_, output, stream));
|
SWITCH_TYPE(output, all_sum<T>(input, output, stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
@ -612,7 +612,39 @@ class RingGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void all_gather(const array& input, array& output, Stream stream) override {
|
void all_gather(const array& input, array& output, Stream stream) override {
|
||||||
throw std::runtime_error("[ring] All gather not supported.");
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(input);
|
||||||
|
encoder.set_output_array(output);
|
||||||
|
encoder.dispatch([input_ptr = input.data<char>(),
|
||||||
|
nbytes = input.nbytes(),
|
||||||
|
output_ptr = output.data<char>(),
|
||||||
|
this]() {
|
||||||
|
constexpr size_t min_send_size = 262144;
|
||||||
|
size_t n_gathers = std::max(
|
||||||
|
std::min(
|
||||||
|
sockets_right_.size() + sockets_left_.size(),
|
||||||
|
nbytes / min_send_size),
|
||||||
|
1UL);
|
||||||
|
size_t bytes_per_gather = ceildiv(nbytes, n_gathers);
|
||||||
|
std::vector<std::future<void>> all_gathers;
|
||||||
|
for (int i = 0; i < n_gathers; i++) {
|
||||||
|
auto offset = i * bytes_per_gather;
|
||||||
|
all_gathers.emplace_back(pool_.enqueue(std::bind(
|
||||||
|
&RingGroup::all_gather_impl,
|
||||||
|
this,
|
||||||
|
input_ptr + offset,
|
||||||
|
output_ptr + offset,
|
||||||
|
nbytes,
|
||||||
|
offset + bytes_per_gather > nbytes ? nbytes - offset
|
||||||
|
: bytes_per_gather,
|
||||||
|
sockets_right_[i / 2],
|
||||||
|
sockets_left_[i / 2],
|
||||||
|
(i % 2) ? -1 : 1)));
|
||||||
|
}
|
||||||
|
for (auto& f : all_gathers) {
|
||||||
|
f.wait();
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void send(const array& input, int dst, Stream stream) override {
|
void send(const array& input, int dst, Stream stream) override {
|
||||||
@ -642,8 +674,7 @@ class RingGroup : public GroupImpl {
|
|||||||
encoder.dispatch(
|
encoder.dispatch(
|
||||||
[out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
|
[out_ptr = out.data<char>(), nbytes = out.nbytes(), src, this]() {
|
||||||
// NOTE: We 'll check the sockets with the opposite order of send so
|
// NOTE: We 'll check the sockets with the opposite order of send so
|
||||||
// that
|
// that they work even with 2 nodes where left and right is the same
|
||||||
// they work even with 2 nodes where left and right is the same
|
|
||||||
// neighbor.
|
// neighbor.
|
||||||
int right = (rank_ + 1) % size_;
|
int right = (rank_ + 1) % size_;
|
||||||
int left = (rank_ + size_ - 1) % size_;
|
int left = (rank_ + size_ - 1) % size_;
|
||||||
@ -827,6 +858,42 @@ class RingGroup : public GroupImpl {
|
|||||||
recvs[b].wait();
|
recvs[b].wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void all_gather_impl(
|
||||||
|
const char* input,
|
||||||
|
char* output,
|
||||||
|
size_t input_size,
|
||||||
|
size_t data_size,
|
||||||
|
int socket_right,
|
||||||
|
int socket_left,
|
||||||
|
int direction) {
|
||||||
|
// Choose which socket we send to and recv from
|
||||||
|
int socket_send = (direction < 0) ? socket_right : socket_left;
|
||||||
|
int socket_recv = (direction < 0) ? socket_left : socket_right;
|
||||||
|
|
||||||
|
// Initial segments
|
||||||
|
int send_segment = rank_;
|
||||||
|
int recv_segment = (rank_ + direction + size_) % size_;
|
||||||
|
|
||||||
|
// Copy our own segment in the output
|
||||||
|
std::memcpy(output + rank_ * input_size, input, data_size);
|
||||||
|
|
||||||
|
// Simple send/recv all gather. Possible performance improvement by
|
||||||
|
// splitting to multiple chunks and allowing send/recv to run a bit ahead.
|
||||||
|
// See all_sum_impl for an example.
|
||||||
|
for (int i = 0; i < size_ - 1; i++) {
|
||||||
|
auto sent = comm_.send(
|
||||||
|
socket_send, output + send_segment * input_size, data_size);
|
||||||
|
auto recvd = comm_.recv(
|
||||||
|
socket_recv, output + recv_segment * input_size, data_size);
|
||||||
|
|
||||||
|
send_segment = (send_segment + size_ + direction) % size_;
|
||||||
|
recv_segment = (recv_segment + size_ + direction) % size_;
|
||||||
|
|
||||||
|
sent.wait();
|
||||||
|
recvd.wait();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
send(const std::vector<int>& sockets, const char* data, size_t data_size) {
|
send(const std::vector<int>& sockets, const char* data, size_t data_size) {
|
||||||
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
||||||
|
@ -56,6 +56,24 @@ class TestRingDistributed(mlx_tests.MLXTestCase):
|
|||||||
maxrelerror = ((y - z).abs() / z.abs()).max()
|
maxrelerror = ((y - z).abs() / z.abs()).max()
|
||||||
self.assertLessEqual(maxrelerror, rtol)
|
self.assertLessEqual(maxrelerror, rtol)
|
||||||
|
|
||||||
|
def test_all_gather(self):
|
||||||
|
world = mx.distributed.init()
|
||||||
|
dtypes = [
|
||||||
|
mx.int8,
|
||||||
|
mx.uint8,
|
||||||
|
mx.int16,
|
||||||
|
mx.uint16,
|
||||||
|
mx.int32,
|
||||||
|
mx.uint32,
|
||||||
|
mx.float32,
|
||||||
|
mx.complex64,
|
||||||
|
]
|
||||||
|
for dt in dtypes:
|
||||||
|
x = mx.ones((2, 2, 4), dtype=dt)
|
||||||
|
y = mx.distributed.all_gather(x)
|
||||||
|
self.assertEqual(y.shape, (world.size() * 2, 2, 4))
|
||||||
|
self.assertTrue(mx.all(y == 1))
|
||||||
|
|
||||||
def test_send_recv(self):
|
def test_send_recv(self):
|
||||||
world = mx.distributed.init()
|
world = mx.distributed.init()
|
||||||
dtypes = [
|
dtypes = [
|
||||||
|
Loading…
Reference in New Issue
Block a user