mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add working reduce and semi-working all gather
This commit is contained in:
@@ -70,9 +70,8 @@ struct Destination {
|
|||||||
};
|
};
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& os, const Destination& dst) {
|
std::ostream& operator<<(std::ostream& os, const Destination& dst) {
|
||||||
os << dst.local_id << " " << dst.queue_pair_number
|
os << dst.local_id << " " << dst.queue_pair_number << " "
|
||||||
<< " " << dst.packet_sequence_number << " "
|
<< dst.packet_sequence_number << " " << dst.global_identifier;
|
||||||
<< dst.global_identifier;
|
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -378,9 +377,10 @@ class SideChannel {
|
|||||||
sockets_.push_back(server.accept(IBV_TAG));
|
sockets_.push_back(server.accept(IBV_TAG));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> ranks(size-1);
|
std::vector<int> ranks(size - 1);
|
||||||
for (int i = 0; i < size - 1; i++) {
|
for (int i = 0; i < size - 1; i++) {
|
||||||
sockets_[i].recv(IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
|
sockets_[i].recv(
|
||||||
|
IBV_TAG, reinterpret_cast<char*>(&ranks[i]), sizeof(int));
|
||||||
ranks[i]--;
|
ranks[i]--;
|
||||||
}
|
}
|
||||||
for (int i = 0; i < size - 1; i++) {
|
for (int i = 0; i < size - 1; i++) {
|
||||||
@@ -739,115 +739,75 @@ class IBVGroup : public GroupImpl {
|
|||||||
// Copy our data to the appropriate place
|
// Copy our data to the appropriate place
|
||||||
std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes);
|
std::memcpy(out_ptr + rank_ * n_bytes, in_ptr, n_bytes);
|
||||||
|
|
||||||
|
// Fully connected all gather
|
||||||
|
char* data = out_ptr;
|
||||||
|
char* our_data = out_ptr + rank_ * n_bytes;
|
||||||
constexpr int64_t N = BUFFER_SIZE;
|
constexpr int64_t N = BUFFER_SIZE;
|
||||||
|
constexpr int PIPELINE = 2;
|
||||||
int64_t total = static_cast<int64_t>(n_bytes);
|
int64_t total = static_cast<int64_t>(n_bytes);
|
||||||
int64_t offset = N;
|
int num_peers = size_ - 1;
|
||||||
int a = 0, b = 1;
|
|
||||||
|
|
||||||
int mask_init = 1 << rank_;
|
// Counters to maintain the state of transfers
|
||||||
int mask_target = (1 << size_) - 1;
|
int in_flight = 0;
|
||||||
|
int read_offset = 0;
|
||||||
|
int completed_send_count[PIPELINE] = {0};
|
||||||
|
int write_offset[MAX_PEERS] = {0};
|
||||||
|
|
||||||
post_recv_all(a);
|
// Prefill the pipeline
|
||||||
std::copy(
|
int buff = 0;
|
||||||
in_ptr,
|
while (read_offset < total && buff < PIPELINE) {
|
||||||
in_ptr + std::min(N, total),
|
post_recv_all(buff);
|
||||||
cm_.send_buffer(a).begin<char>());
|
|
||||||
post_send_all(a);
|
|
||||||
|
|
||||||
int mask_a_send = mask_init;
|
|
||||||
int mask_a_recv = mask_init;
|
|
||||||
int mask_b_send = mask_init;
|
|
||||||
int mask_b_recv = mask_init;
|
|
||||||
|
|
||||||
while (offset < total) {
|
|
||||||
// While the previous chunk is in flight copy to the next send buffer
|
|
||||||
std::copy(
|
std::copy(
|
||||||
in_ptr + offset,
|
our_data + read_offset,
|
||||||
in_ptr + std::min(offset + N, total),
|
our_data + std::min(read_offset + N, total),
|
||||||
cm_.send_buffer(b).begin<char>());
|
cm_.send_buffer(buff).begin<char>());
|
||||||
|
post_send_all(buff);
|
||||||
|
|
||||||
// Send if the previous send is already done
|
buff++;
|
||||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
in_flight += 2 * num_peers;
|
||||||
if (i == rank_) {
|
read_offset += N;
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (mask_a_send & m) {
|
|
||||||
cm_.send_to(i, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recv the next buffer if the previous one is already done
|
|
||||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
|
||||||
if (i == rank_) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (mask_a_recv & m) {
|
|
||||||
cm_.recv_from(i, b);
|
|
||||||
std::copy(
|
|
||||||
cm_.buffer(i, a).begin<char>(),
|
|
||||||
cm_.buffer(i, a).begin<char>() + std::min(N, total - offset),
|
|
||||||
out_ptr + i * n_bytes + std::max(offset - N, 0LL));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop until this chunk is all done
|
|
||||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
|
||||||
ibv_wc wc[8];
|
|
||||||
int n = cm_.poll(8, wc);
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
int work_type = wc[i].wr_id >> 16;
|
|
||||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
|
||||||
int rank = wc[i].wr_id & 0xff;
|
|
||||||
|
|
||||||
if (work_type == SEND_WR) {
|
|
||||||
if (buff == a) {
|
|
||||||
cm_.send_to(rank, b);
|
|
||||||
mask_a_send |= 1 << rank;
|
|
||||||
} else {
|
|
||||||
mask_b_send |= 1 << rank;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (buff == a) {
|
|
||||||
cm_.recv_from(rank, b);
|
|
||||||
mask_a_recv |= 1 << rank;
|
|
||||||
std::copy(
|
|
||||||
cm_.buffer(rank, a).begin<char>(),
|
|
||||||
cm_.buffer(rank, a).begin<char>() +
|
|
||||||
std::min(N, total - offset),
|
|
||||||
out_ptr + rank * n_bytes + std::max(offset - N, 0LL));
|
|
||||||
} else {
|
|
||||||
mask_b_recv |= 1 << rank;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::swap(a, b);
|
|
||||||
mask_a_send = mask_b_send;
|
|
||||||
mask_a_recv = mask_b_recv;
|
|
||||||
mask_b_send = mask_init;
|
|
||||||
mask_b_recv = mask_init;
|
|
||||||
offset += N;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
// Main loop
|
||||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
//
|
||||||
ibv_wc wc[8];
|
// Keep going until we have no longer data in flight.
|
||||||
int n = cm_.poll(8, wc);
|
while (in_flight > 0) {
|
||||||
for (int i = 0; i < n; i++) {
|
ibv_wc wc[8];
|
||||||
int work_type = wc[i].wr_id >> 16;
|
int n = cm_.poll(8, wc);
|
||||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
for (int i = 0; i < n; i++) {
|
||||||
int rank = wc[i].wr_id & 0xff;
|
int work_type = wc[i].wr_id >> 16;
|
||||||
|
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||||
|
int rank = wc[i].wr_id & 0xff;
|
||||||
|
|
||||||
if (work_type == SEND_WR) {
|
in_flight--;
|
||||||
mask_a_send |= 1 << rank;
|
|
||||||
} else {
|
// Send completed. If all sends completed then send the next chunk.
|
||||||
mask_a_recv |= 1 << rank;
|
if (work_type == SEND_WR && read_offset < total) {
|
||||||
|
completed_send_count[buff]++;
|
||||||
|
if (completed_send_count[buff] == num_peers) {
|
||||||
std::copy(
|
std::copy(
|
||||||
cm_.buffer(rank, a).begin<char>(),
|
our_data + read_offset,
|
||||||
cm_.buffer(rank, a).begin<char>() +
|
our_data + std::min(read_offset + N, total),
|
||||||
std::min(N, total + N - offset),
|
cm_.send_buffer(buff).begin<char>());
|
||||||
out_ptr + rank * n_bytes + offset - N);
|
post_send_all(buff);
|
||||||
|
|
||||||
|
completed_send_count[buff] = 0;
|
||||||
|
in_flight += num_peers;
|
||||||
|
read_offset += N;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recv completed. If we have more chunks then post another recv.
|
||||||
|
else if (work_type == RECV_WR) {
|
||||||
|
std::copy(
|
||||||
|
cm_.buffer(rank, buff).begin<char>(),
|
||||||
|
cm_.buffer(rank, buff).begin<char>() +
|
||||||
|
std::min(N, total - write_offset[rank]),
|
||||||
|
data + rank * n_bytes + write_offset[rank]);
|
||||||
|
write_offset[rank] += N;
|
||||||
|
if (write_offset[rank] + N * (PIPELINE - 1) < total) {
|
||||||
|
cm_.recv_from(rank, buff);
|
||||||
|
in_flight++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -902,112 +862,99 @@ class IBVGroup : public GroupImpl {
|
|||||||
// Fully connected all reduce
|
// Fully connected all reduce
|
||||||
T* data = out_ptr;
|
T* data = out_ptr;
|
||||||
constexpr int64_t N = BUFFER_SIZE / sizeof(T);
|
constexpr int64_t N = BUFFER_SIZE / sizeof(T);
|
||||||
|
constexpr int PIPELINE = 2;
|
||||||
int64_t total = static_cast<int64_t>(size);
|
int64_t total = static_cast<int64_t>(size);
|
||||||
int64_t offset = N;
|
int num_peers = size_ - 1;
|
||||||
int a = 0, b = 1;
|
|
||||||
|
|
||||||
int mask_init = 1 << rank_;
|
// Counters to maintain the state of transfers
|
||||||
int mask_target = (1 << size_) - 1;
|
int in_flight = 0;
|
||||||
|
int read_offset = 0;
|
||||||
|
int completed_send_count[PIPELINE] = {0};
|
||||||
|
int completed_recv_begin[MAX_PEERS] = {0};
|
||||||
|
int completed_recv_end[MAX_PEERS] = {0};
|
||||||
|
|
||||||
// Handle the first piece of data.
|
// Prefill the pipeline
|
||||||
post_recv_all(a);
|
int buff = 0;
|
||||||
std::copy(data, data + std::min(N, total), cm_.send_buffer(a).begin<T>());
|
while (read_offset < total && buff < PIPELINE) {
|
||||||
post_send_all(a);
|
post_recv_all(buff);
|
||||||
|
|
||||||
int mask_a_send = mask_init;
|
|
||||||
int mask_a_recv = mask_init;
|
|
||||||
int mask_b_send = mask_init;
|
|
||||||
int mask_b_recv = mask_init;
|
|
||||||
|
|
||||||
while (offset < total) {
|
|
||||||
// While the previous chunk is in flight copy to the next send buffer
|
|
||||||
std::copy(
|
std::copy(
|
||||||
data + offset,
|
data + read_offset,
|
||||||
data + std::min(offset + N, total),
|
data + std::min(read_offset + N, total),
|
||||||
cm_.send_buffer(b).begin<T>());
|
cm_.send_buffer(buff).begin<T>());
|
||||||
|
post_send_all(buff);
|
||||||
|
|
||||||
// Send if the previous send is already done
|
buff++;
|
||||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
in_flight += 2 * num_peers;
|
||||||
if (i == rank_) {
|
read_offset += N;
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (mask_a_send & m) {
|
|
||||||
cm_.send_to(i, b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recv the next buffer if the previous one is already done
|
|
||||||
for (int i = 0, m = 1; i < size_; i++, m *= 2) {
|
|
||||||
if (i == rank_) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (mask_a_recv & m) {
|
|
||||||
cm_.recv_from(i, b);
|
|
||||||
reduce_op(
|
|
||||||
cm_.buffer(i, a).begin<T>(),
|
|
||||||
data + std::max(offset - N, 0LL),
|
|
||||||
std::min(N, total - offset));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop until this chunk is all done
|
|
||||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
|
||||||
ibv_wc wc[8];
|
|
||||||
int n = cm_.poll(8, wc);
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
int work_type = wc[i].wr_id >> 16;
|
|
||||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
|
||||||
int rank = wc[i].wr_id & 0xff;
|
|
||||||
|
|
||||||
if (work_type == SEND_WR) {
|
|
||||||
if (buff == a) {
|
|
||||||
cm_.send_to(rank, b);
|
|
||||||
mask_a_send |= 1 << rank;
|
|
||||||
} else {
|
|
||||||
mask_b_send |= 1 << rank;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (buff == a) {
|
|
||||||
cm_.recv_from(rank, b);
|
|
||||||
mask_a_recv |= 1 << rank;
|
|
||||||
reduce_op(
|
|
||||||
cm_.buffer(rank, a).begin<T>(),
|
|
||||||
data + std::max(offset - N, 0LL),
|
|
||||||
std::min(N, total - offset));
|
|
||||||
} else {
|
|
||||||
mask_b_recv |= 1 << rank;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::swap(a, b);
|
|
||||||
mask_a_send = mask_b_send;
|
|
||||||
mask_a_recv = mask_b_recv;
|
|
||||||
mask_b_send = mask_init;
|
|
||||||
mask_b_recv = mask_init;
|
|
||||||
offset += N;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
// Main loop
|
||||||
while (mask_a_send != mask_target || mask_a_recv != mask_target) {
|
//
|
||||||
ibv_wc wc[8];
|
// Keep going until we have no longer data in flight.
|
||||||
int n = cm_.poll(8, wc);
|
while (in_flight > 0) {
|
||||||
for (int i = 0; i < n; i++) {
|
// Poll the hardware for completions.
|
||||||
int work_type = wc[i].wr_id >> 16;
|
//
|
||||||
int buff = (wc[i].wr_id >> 8) & 0xff;
|
// If a send was completed mark how many completions we have received
|
||||||
int rank = wc[i].wr_id & 0xff;
|
// for that buffer. If we have sent the buffer to all peers we can
|
||||||
|
// reuse the buffer so copy the next chunk of data and send it to all.
|
||||||
|
//
|
||||||
|
// If a receive is completed then advance the pointer of completed
|
||||||
|
// receives.
|
||||||
|
ibv_wc wc[8];
|
||||||
|
int n = cm_.poll(8, wc);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
int work_type = wc[i].wr_id >> 16;
|
||||||
|
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||||
|
int rank = wc[i].wr_id & 0xff;
|
||||||
|
|
||||||
if (work_type == SEND_WR) {
|
in_flight--;
|
||||||
mask_a_send |= 1 << rank;
|
|
||||||
} else {
|
if (work_type == SEND_WR && read_offset < total) {
|
||||||
mask_a_recv |= 1 << rank;
|
completed_send_count[buff]++;
|
||||||
reduce_op(
|
if (completed_send_count[buff] == num_peers) {
|
||||||
cm_.buffer(rank, a).begin<T>(),
|
std::copy(
|
||||||
data + offset - N,
|
data + read_offset,
|
||||||
std::min(N, total + N - offset));
|
data + std::min(read_offset + N, total),
|
||||||
|
cm_.send_buffer(buff).begin<T>());
|
||||||
|
post_send_all(buff);
|
||||||
|
|
||||||
|
completed_send_count[buff] = 0;
|
||||||
|
in_flight += num_peers;
|
||||||
|
read_offset += N;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
else if (work_type == RECV_WR) {
|
||||||
|
completed_recv_end[rank]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the completed recv
|
||||||
|
//
|
||||||
|
// For each rank we have a range of completed recv defined by a begin
|
||||||
|
// and end inclusive and exlusive in standard C++ fashion.
|
||||||
|
//
|
||||||
|
// When there is an unprocessed receive we first check if we have
|
||||||
|
// finished sending the write location. If so then we reduce in-place
|
||||||
|
// and then check if there is more to be received and post a recv.
|
||||||
|
for (int r = 0; r < size_; r++) {
|
||||||
|
int s = completed_recv_begin[r];
|
||||||
|
int e = completed_recv_end[r];
|
||||||
|
int w = s * N;
|
||||||
|
while (w < read_offset && e - s > 0) {
|
||||||
|
int buff = s % PIPELINE;
|
||||||
|
reduce_op(
|
||||||
|
cm_.buffer(r, buff).begin<T>(),
|
||||||
|
data + w,
|
||||||
|
std::min(N, total - w));
|
||||||
|
w += N;
|
||||||
|
s++;
|
||||||
|
if (w + (PIPELINE - 1) * N < total) {
|
||||||
|
cm_.recv_from(r, buff);
|
||||||
|
in_flight++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
completed_recv_begin[r] = s;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user