Make sure that there is space for work completions

This commit is contained in:
Angelos Katharopoulos
2025-10-28 11:27:20 -07:00
parent 2d10020178
commit 29d9cd836a

View File

@@ -744,6 +744,7 @@ class IBVGroup : public GroupImpl {
char* our_data = out_ptr + rank_ * n_bytes; char* our_data = out_ptr + rank_ * n_bytes;
constexpr int64_t N = BUFFER_SIZE; constexpr int64_t N = BUFFER_SIZE;
constexpr int PIPELINE = 2; constexpr int PIPELINE = 2;
constexpr int WC_NUM = PIPELINE * MAX_PEERS * 2;
int64_t total = static_cast<int64_t>(n_bytes); int64_t total = static_cast<int64_t>(n_bytes);
int num_peers = size_ - 1; int num_peers = size_ - 1;
@@ -772,8 +773,8 @@ class IBVGroup : public GroupImpl {
// //
// Keep going until we no longer have data in flight. // Keep going until we no longer have data in flight.
while (in_flight > 0) { while (in_flight > 0) {
ibv_wc wc[8]; ibv_wc wc[WC_NUM];
int n = cm_.poll(8, wc); int n = cm_.poll(WC_NUM, wc);
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
int work_type = wc[i].wr_id >> 16; int work_type = wc[i].wr_id >> 16;
int buff = (wc[i].wr_id >> 8) & 0xff; int buff = (wc[i].wr_id >> 8) & 0xff;
@@ -863,6 +864,7 @@ class IBVGroup : public GroupImpl {
T* data = out_ptr; T* data = out_ptr;
constexpr int64_t N = BUFFER_SIZE / sizeof(T); constexpr int64_t N = BUFFER_SIZE / sizeof(T);
constexpr int PIPELINE = 2; constexpr int PIPELINE = 2;
constexpr int WC_NUM = PIPELINE * MAX_PEERS * 2;
int64_t total = static_cast<int64_t>(size); int64_t total = static_cast<int64_t>(size);
int num_peers = size_ - 1; int num_peers = size_ - 1;
@@ -900,8 +902,8 @@ class IBVGroup : public GroupImpl {
// //
// If a receive is completed then advance the pointer of completed // If a receive is completed then advance the pointer of completed
// receives. // receives.
ibv_wc wc[8]; ibv_wc wc[WC_NUM];
int n = cm_.poll(8, wc); int n = cm_.poll(WC_NUM, wc);
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
int work_type = wc[i].wr_id >> 16; int work_type = wc[i].wr_id >> 16;
int buff = (wc[i].wr_id >> 8) & 0xff; int buff = (wc[i].wr_id >> 8) & 0xff;