From 29d9cd836a498bb0ed823a5ad2b8b5f092fe642d Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 28 Oct 2025 11:27:20 -0700 Subject: [PATCH] Make sure that there is space for work completions --- mlx/distributed/ibv/ibv.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlx/distributed/ibv/ibv.cpp b/mlx/distributed/ibv/ibv.cpp index 2a4c4644b..edcf52e49 100644 --- a/mlx/distributed/ibv/ibv.cpp +++ b/mlx/distributed/ibv/ibv.cpp @@ -744,6 +744,7 @@ class IBVGroup : public GroupImpl { char* our_data = out_ptr + rank_ * n_bytes; constexpr int64_t N = BUFFER_SIZE; constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * MAX_PEERS * 2; int64_t total = static_cast<int64_t>(n_bytes); int num_peers = size_ - 1; @@ -772,8 +773,8 @@ class IBVGroup : public GroupImpl { // // Keep going until we have no longer data in flight. while (in_flight > 0) { - ibv_wc wc[8]; - int n = cm_.poll(8, wc); + ibv_wc wc[WC_NUM]; + int n = cm_.poll(WC_NUM, wc); for (int i = 0; i < n; i++) { int work_type = wc[i].wr_id >> 16; int buff = (wc[i].wr_id >> 8) & 0xff; @@ -863,6 +864,7 @@ class IBVGroup : public GroupImpl { T* data = out_ptr; constexpr int64_t N = BUFFER_SIZE / sizeof(T); constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * MAX_PEERS * 2; int64_t total = static_cast<int64_t>(size); int num_peers = size_ - 1; @@ -900,8 +902,8 @@ class IBVGroup : public GroupImpl { // // If a receive is completed then advance the pointer of completed // receives. - ibv_wc wc[8]; - int n = cm_.poll(8, wc); + ibv_wc wc[WC_NUM]; + int n = cm_.poll(WC_NUM, wc); for (int i = 0; i < n; i++) { int work_type = wc[i].wr_id >> 16; int buff = (wc[i].wr_id >> 8) & 0xff;