mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add working reduce and semi-working all gather
This commit is contained in:
@@ -70,9 +70,8 @@ struct Destination {
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Destination& dst) {
|
||||
os << dst.local_id << " " << dst.queue_pair_number
|
||||
<< " " << dst.packet_sequence_number << " "
|
||||
<< dst.global_identifier;
|
||||
os << dst.local_id << " " << dst.queue_pair_number << " "
|
||||
<< dst.packet_sequence_number << " " << dst.global_identifier;
|
||||
return os;
|
||||
}
|
||||
|
||||
@@ -378,9 +377,10 @@ class SideChannel {
|
||||
sockets_.push_back(server.accept(IBV_TAG));
|
||||
}
|
||||
|
||||
std::vector<int> ranks(size-1);
|
||||
std::vector<int> ranks(size - 1);
|
||||
for (int i = 0; i < size - 1; i++) {
|
||||
sockets_[i].recv(IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
|
||||
sockets_[i].recv(
|
||||
IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
|
||||
ranks[i]--;
|
||||
}
|
||||
for (int i = 0; i < size - 1; i++) {
|
||||
@@ -739,115 +739,75 @@ class IBVGroup : public GroupImpl {
|
||||
// Copy our data to the appropriate place
|
||||
std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes);
|
||||
|
||||
// Fully connected all gather
|
||||
char* data = out_ptr;
|
||||
char* our_data = out_ptr + rank_ * n_bytes;
|
||||
constexpr int64_t N = BUFFER_SIZE;
|
||||
constexpr int PIPELINE = 2;
|
||||
int64_t total = static_cast<int64_t>(n_bytes);
|
||||
int64_t offset = N;
|
||||
int a = 0, b = 1;
|
||||
int num_peers = size_ - 1;
|
||||
|
||||
int mask_init = 1 << rank_;
|
||||
int mask_target = (1 << size_) - 1;
|
||||
// Counters to maintain the state of transfers
|
||||
int in_flight = 0;
|
||||
int read_offset = 0;
|
||||
int completed_send_count[PIPELINE] = {0};
|
||||
int write_offset[MAX_PEERS] = {0};
|
||||
|
||||
post_recv_all(a);
|
||||
std::copy(
|
||||
in_ptr,
|
||||
in_ptr + std::min(N, total),
|
||||
cm_.send_buffer(a).begin<char>());
|
||||
post_send_all(a);
|
||||
|
||||
int mask_a_send = mask_init;
|
||||
int mask_a_recv = mask_init;
|
||||
int mask_b_send = mask_init;
|
||||
int mask_b_recv = mask_init;
|
||||
|
||||
while (offset < total) {
|
||||
// While the previous chunk is in flight copy to the next send buffer
|
||||
// Prefill the pipeline
|
||||
int buff = 0;
|
||||
while (read_offset < total && buff < PIPELINE) {
|
||||
post_recv_all(buff);
|
||||
std::copy(
|
||||
in_ptr + offset,
|
||||
in_ptr + std::min(offset + N, total),
|
||||
cm_.send_buffer(b).begin<char>());
|
||||
our_data + read_offset,
|
||||
our_data + std::min(read_offset + N, total),
|
||||
cm_.send_buffer(buff).begin<char>());
|
||||
post_send_all(buff);
|
||||
|
||||
// Send if the previous send is already done
|
||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
if (mask_a_send & m) {
|
||||
cm_.send_to(i, b);
|
||||
}
|
||||
}
|
||||
|
||||
// Recv the next buffer if the previous one is already done
|
||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
if (mask_a_recv & m) {
|
||||
cm_.recv_from(i, b);
|
||||
std::copy(
|
||||
cm_.buffer(i, a).begin<char>(),
|
||||
cm_.buffer(i, a).begin<char>() + std::min(N, total - offset),
|
||||
out_ptr + i * n_bytes + std::max(offset - N, 0LL));
|
||||
}
|
||||
}
|
||||
|
||||
// Loop until this chunk is all done
|
||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
|
||||
if (work_type == SEND_WR) {
|
||||
if (buff == a) {
|
||||
cm_.send_to(rank, b);
|
||||
mask_a_send |= 1 << rank;
|
||||
} else {
|
||||
mask_b_send |= 1 << rank;
|
||||
}
|
||||
} else {
|
||||
if (buff == a) {
|
||||
cm_.recv_from(rank, b);
|
||||
mask_a_recv |= 1 << rank;
|
||||
std::copy(
|
||||
cm_.buffer(rank, a).begin<char>(),
|
||||
cm_.buffer(rank, a).begin<char>() +
|
||||
std::min(N, total - offset),
|
||||
out_ptr + rank * n_bytes + std::max(offset - N, 0LL));
|
||||
} else {
|
||||
mask_b_recv |= 1 << rank;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::swap(a, b);
|
||||
mask_a_send = mask_b_send;
|
||||
mask_a_recv = mask_b_recv;
|
||||
mask_b_send = mask_init;
|
||||
mask_b_recv = mask_init;
|
||||
offset += N;
|
||||
buff++;
|
||||
in_flight += 2 * num_peers;
|
||||
read_offset += N;
|
||||
}
|
||||
|
||||
{
|
||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
// Main loop
|
||||
//
|
||||
// Keep going until we have no longer data in flight.
|
||||
while (in_flight > 0) {
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
|
||||
if (work_type == SEND_WR) {
|
||||
mask_a_send |= 1 << rank;
|
||||
} else {
|
||||
mask_a_recv |= 1 << rank;
|
||||
in_flight--;
|
||||
|
||||
// Send completed. If all sends completed then send the next chunk.
|
||||
if (work_type == SEND_WR && read_offset < total) {
|
||||
completed_send_count[buff]++;
|
||||
if (completed_send_count[buff] == num_peers) {
|
||||
std::copy(
|
||||
cm_.buffer(rank, a).begin<char>(),
|
||||
cm_.buffer(rank, a).begin<char>() +
|
||||
std::min(N, total + N - offset),
|
||||
out_ptr + rank * n_bytes + offset - N);
|
||||
our_data + read_offset,
|
||||
our_data + std::min(read_offset + N, total),
|
||||
cm_.send_buffer(buff).begin<char>());
|
||||
post_send_all(buff);
|
||||
|
||||
completed_send_count[buff] = 0;
|
||||
in_flight += num_peers;
|
||||
read_offset += N;
|
||||
}
|
||||
}
|
||||
|
||||
// Recv completed. If we have more chunks then post another recv.
|
||||
else if (work_type == RECV_WR) {
|
||||
std::copy(
|
||||
cm_.buffer(rank, buff).begin<char>(),
|
||||
cm_.buffer(rank, buff).begin<char>() +
|
||||
std::min(N, total - write_offset[rank]),
|
||||
data + rank * n_bytes + write_offset[rank]);
|
||||
write_offset[rank] += N;
|
||||
if (write_offset[rank] + N * (PIPELINE - 1) < total) {
|
||||
cm_.recv_from(rank, buff);
|
||||
in_flight++;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -902,112 +862,99 @@ class IBVGroup : public GroupImpl {
|
||||
// Fully connected all reduce
|
||||
T* data = out_ptr;
|
||||
constexpr int64_t N = BUFFER_SIZE / sizeof(T);
|
||||
constexpr int PIPELINE = 2;
|
||||
int64_t total = static_cast<int64_t>(size);
|
||||
int64_t offset = N;
|
||||
int a = 0, b = 1;
|
||||
int num_peers = size_ - 1;
|
||||
|
||||
int mask_init = 1 << rank_;
|
||||
int mask_target = (1 << size_) - 1;
|
||||
// Counters to maintain the state of transfers
|
||||
int in_flight = 0;
|
||||
int read_offset = 0;
|
||||
int completed_send_count[PIPELINE] = {0};
|
||||
int completed_recv_begin[MAX_PEERS] = {0};
|
||||
int completed_recv_end[MAX_PEERS] = {0};
|
||||
|
||||
// Handle the first piece of data.
|
||||
post_recv_all(a);
|
||||
std::copy(data, data + std::min(N, total), cm_.send_buffer(a).begin<T>());
|
||||
post_send_all(a);
|
||||
|
||||
int mask_a_send = mask_init;
|
||||
int mask_a_recv = mask_init;
|
||||
int mask_b_send = mask_init;
|
||||
int mask_b_recv = mask_init;
|
||||
|
||||
while (offset < total) {
|
||||
// While the previous chunk is in flight copy to the next send buffer
|
||||
// Prefill the pipeline
|
||||
int buff = 0;
|
||||
while (read_offset < total && buff < PIPELINE) {
|
||||
post_recv_all(buff);
|
||||
std::copy(
|
||||
data + offset,
|
||||
data + std::min(offset + N, total),
|
||||
cm_.send_buffer(b).begin<T>());
|
||||
data + read_offset,
|
||||
data + std::min(read_offset + N, total),
|
||||
cm_.send_buffer(buff).begin<T>());
|
||||
post_send_all(buff);
|
||||
|
||||
// Send if the previous send is already done
|
||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
if (mask_a_send & m) {
|
||||
cm_.send_to(i, b);
|
||||
}
|
||||
}
|
||||
|
||||
// Recv the next buffer if the previous one is already done
|
||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
||||
if (i == rank_) {
|
||||
continue;
|
||||
}
|
||||
if (mask_a_recv & m) {
|
||||
cm_.recv_from(i, b);
|
||||
reduce_op(
|
||||
cm_.buffer(i, a).begin<T>(),
|
||||
data + std::max(offset - N, 0LL),
|
||||
std::min(N, total - offset));
|
||||
}
|
||||
}
|
||||
|
||||
// Loop until this chunk is all done
|
||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
|
||||
if (work_type == SEND_WR) {
|
||||
if (buff == a) {
|
||||
cm_.send_to(rank, b);
|
||||
mask_a_send |= 1 << rank;
|
||||
} else {
|
||||
mask_b_send |= 1 << rank;
|
||||
}
|
||||
} else {
|
||||
if (buff == a) {
|
||||
cm_.recv_from(rank, b);
|
||||
mask_a_recv |= 1 << rank;
|
||||
reduce_op(
|
||||
cm_.buffer(rank, a).begin<T>(),
|
||||
data + std::max(offset - N, 0LL),
|
||||
std::min(N, total - offset));
|
||||
} else {
|
||||
mask_b_recv |= 1 << rank;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::swap(a, b);
|
||||
mask_a_send = mask_b_send;
|
||||
mask_a_recv = mask_b_recv;
|
||||
mask_b_send = mask_init;
|
||||
mask_b_recv = mask_init;
|
||||
offset += N;
|
||||
buff++;
|
||||
in_flight += 2 * num_peers;
|
||||
read_offset += N;
|
||||
}
|
||||
|
||||
{
|
||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
// Main loop
|
||||
//
|
||||
// Keep going until we have no longer data in flight.
|
||||
while (in_flight > 0) {
|
||||
// Poll the hardware for completions.
|
||||
//
|
||||
// If a send was completed mark how many completions we have received
|
||||
// for that buffer. If we have sent the buffer to all peers we can
|
||||
// reuse the buffer so copy the next chunk of data and send it to all.
|
||||
//
|
||||
// If a receive is completed then advance the pointer of completed
|
||||
// receives.
|
||||
ibv_wc wc[8];
|
||||
int n = cm_.poll(8, wc);
|
||||
for (int i = 0; i < n; i++) {
|
||||
int work_type = wc[i].wr_id >> 16;
|
||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||
int rank = wc[i].wr_id & 0xff;
|
||||
|
||||
if (work_type == SEND_WR) {
|
||||
mask_a_send |= 1 << rank;
|
||||
} else {
|
||||
mask_a_recv |= 1 << rank;
|
||||
reduce_op(
|
||||
cm_.buffer(rank, a).begin<T>(),
|
||||
data + offset - N,
|
||||
std::min(N, total + N - offset));
|
||||
in_flight--;
|
||||
|
||||
if (work_type == SEND_WR && read_offset < total) {
|
||||
completed_send_count[buff]++;
|
||||
if (completed_send_count[buff] == num_peers) {
|
||||
std::copy(
|
||||
data + read_offset,
|
||||
data + std::min(read_offset + N, total),
|
||||
cm_.send_buffer(buff).begin<T>());
|
||||
post_send_all(buff);
|
||||
|
||||
completed_send_count[buff] = 0;
|
||||
in_flight += num_peers;
|
||||
read_offset += N;
|
||||
}
|
||||
}
|
||||
|
||||
else if (work_type == RECV_WR) {
|
||||
completed_recv_end[rank]++;
|
||||
}
|
||||
}
|
||||
|
||||
// Process the completed recv
|
||||
//
|
||||
// For each rank we have a range of completed recv defined by a begin
|
||||
// and end inclusive and exlusive in standard C++ fashion.
|
||||
//
|
||||
// When there is an unprocessed receive we first check if we have
|
||||
// finished sending the write location. If so then we reduce in-place
|
||||
// and then check if there is more to be received and post a recv.
|
||||
for (int r = 0; r < size_; r++) {
|
||||
int s = completed_recv_begin[r];
|
||||
int e = completed_recv_end[r];
|
||||
int w = s * N;
|
||||
while (w < read_offset && e - s > 0) {
|
||||
int buff = s % PIPELINE;
|
||||
reduce_op(
|
||||
cm_.buffer(r, buff).begin<T>(),
|
||||
data + w,
|
||||
std::min(N, total - w));
|
||||
w += N;
|
||||
s++;
|
||||
if (w + (PIPELINE - 1) * N < total) {
|
||||
cm_.recv_from(r, buff);
|
||||
in_flight++;
|
||||
}
|
||||
}
|
||||
completed_recv_begin[r] = s;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user