From 45727b0c027783c121efbb3c417be1f9091899da Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 29 Oct 2025 14:09:25 -0700 Subject: [PATCH] Add send/recv --- mlx/distributed/ibv/ibv.cpp | 109 +++++++++++++++++++++++++++++++++++- 1 file changed, 107 insertions(+), 2 deletions(-) diff --git a/mlx/distributed/ibv/ibv.cpp b/mlx/distributed/ibv/ibv.cpp index edcf52e49..866147617 100644 --- a/mlx/distributed/ibv/ibv.cpp +++ b/mlx/distributed/ibv/ibv.cpp @@ -816,9 +816,114 @@ class IBVGroup : public GroupImpl { }); } - void send(const array& input, int dst, Stream stream) override {} + void send(const array& input, int dst, Stream stream) override { + auto data = input.data(); + int64_t n_bytes = input.nbytes(); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.dispatch([data, n_bytes, dst, this]() { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + constexpr int N = BUFFER_SIZE; - void recv(array& out, int src, Stream stream) override {} + int in_flight = 0; + int64_t read_offset = 0; + + // Prefill the pipeline + int buff = 0; + while (read_offset < n_bytes && buff < PIPELINE) { + std::copy( + data + read_offset, + data + std::min(read_offset + N, n_bytes), + cm_.send_buffer(buff).begin()); + cm_.send_to(dst, buff); + + buff++; + read_offset += N; + in_flight++; + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a send was completed and we have more data to send then go ahead + // and send them. + ibv_wc wc[WC_NUM]; + int n = cm_.poll(WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + if (read_offset < n_bytes) { + std::copy( + data + read_offset, + data + std::min(read_offset + N, n_bytes), + cm_.send_buffer(buff).begin()); + cm_.send_to(dst, buff); + + read_offset += N; + in_flight++; + } + } + } + }); + } + + void recv(array& out, int src, Stream stream) override { + auto data = out.data(); + int64_t n_bytes = out.nbytes(); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(out); + encoder.dispatch([data, n_bytes, src, this]() { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + constexpr int N = BUFFER_SIZE; + + int in_flight = 0; + int64_t write_offset = 0; + + // Prefill the pipeline + int buff = 0; + while (write_offset < n_bytes && buff < PIPELINE) { + cm_.recv_from(src, buff); + + in_flight++; + buff++; + } + + // Main loop + while (in_flight > 0) { + // Poll the hardware for completions. + // + // If a recv was completed copy it to the output and if we have more + // data to fetch post another recv. + ibv_wc wc[WC_NUM]; + int n = cm_.poll(WC_NUM, wc); + for (int i = 0; i < n; i++) { + int buff = (wc[i].wr_id >> 8) & 0xff; + int rank = wc[i].wr_id & 0xff; + + in_flight--; + + std::copy( + cm_.buffer(src, buff).begin(), + cm_.buffer(src, buff).begin() + + std::min(n_bytes - write_offset, static_cast(N)), + data + write_offset); + write_offset += N; + + if (write_offset + (PIPELINE - 1) * N < n_bytes) { + cm_.recv_from(src, buff); + + in_flight++; + } + } + } + }); + } std::shared_ptr split(int color, int key = -1) override { throw std::runtime_error("[ibv] Group split not supported.");