Add send/recv

This commit is contained in:
Angelos Katharopoulos
2025-10-29 14:09:25 -07:00
parent 2444fbdfe9
commit 45727b0c02

View File

@@ -816,9 +816,114 @@ class IBVGroup : public GroupImpl {
});
}
/**
 * Send the raw bytes of `input` to rank `dst`.
 *
 * The transfer is dispatched as a task on the CPU command encoder of
 * `stream`. The payload is cut into chunks of at most BUFFER_SIZE bytes;
 * each chunk is copied into one of the connection manager's send buffers
 * and posted, keeping at most PIPELINE chunks outstanding at a time.
 *
 * @param input  Array whose bytes are transmitted (registered as an input of
 *               the encoder before dispatch).
 * @param dst    Destination rank passed to `cm_.send_to`.
 * @param stream Stream whose CPU command encoder runs the task.
 */
void send(const array& input, int dst, Stream stream) override {
  auto data = input.data<char>();
  int64_t n_bytes = input.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_input_array(input);
  encoder.dispatch([data, n_bytes, dst, this]() {
    constexpr int PIPELINE = 2;
    constexpr int WC_NUM = PIPELINE;
    constexpr int64_t N = BUFFER_SIZE;

    int in_flight = 0;
    int64_t read_offset = 0;

    // Copy the next chunk (the last one may be shorter than N) into send
    // buffer `buff`, post it to `dst` and account for it. `read_offset` may
    // step past `n_bytes` after the last chunk; it is only ever compared
    // against `n_bytes` afterwards.
    auto post_next_chunk = [&](int buff) {
      std::copy(
          data + read_offset,
          data + std::min(read_offset + N, n_bytes),
          cm_.send_buffer(buff).begin<char>());
      cm_.send_to(dst, buff);
      read_offset += N;
      in_flight++;
    };

    // Prefill the pipeline: one chunk per send buffer while data remains.
    for (int buff = 0; read_offset < n_bytes && buff < PIPELINE; buff++) {
      post_next_chunk(buff);
    }

    // Main loop
    while (in_flight > 0) {
      // Poll the hardware for completions.
      //
      // If a send was completed and we have more data to send then reuse the
      // buffer that just became free for the next chunk.
      ibv_wc wc[WC_NUM];
      int n = cm_.poll(WC_NUM, wc);
      for (int i = 0; i < n; i++) {
        // The buffer index is read from bits 8-15 of the work request id.
        int buff = (wc[i].wr_id >> 8) & 0xff;
        in_flight--;

        if (read_offset < n_bytes) {
          post_next_chunk(buff);
        }
      }
    }
  });
}
/**
 * Receive `out.nbytes()` bytes from rank `src` into `out`.
 *
 * The transfer is dispatched as a task on the CPU command encoder of
 * `stream`. Data arrives in chunks of at most BUFFER_SIZE bytes; up to
 * PIPELINE receives are kept posted, and every completed chunk is copied
 * from the receive buffer to the next position in `out`.
 *
 * @param out    Destination array (registered as an output of the encoder
 *               before dispatch).
 * @param src    Source rank passed to `cm_.recv_from` / `cm_.buffer`.
 * @param stream Stream whose CPU command encoder runs the task.
 */
void recv(array& out, int src, Stream stream) override {
  auto data = out.data<char>();
  int64_t n_bytes = out.nbytes();
  auto& encoder = cpu::get_command_encoder(stream);
  encoder.set_output_array(out);
  encoder.dispatch([data, n_bytes, src, this]() {
    constexpr int PIPELINE = 2;
    constexpr int WC_NUM = PIPELINE;
    constexpr int64_t N = BUFFER_SIZE;

    int in_flight = 0;
    int64_t write_offset = 0;

    // Prefill the pipeline with one receive per expected chunk, capped at
    // PIPELINE. The bound must be the chunk count (buff * N < n_bytes) and
    // not `write_offset < n_bytes`: `write_offset` does not move here, so
    // that test would post PIPELINE receives even for a payload that fits in
    // a single chunk, and the main loop would then wait on a receive for
    // which no chunk is expected.
    for (int buff = 0; buff < PIPELINE && buff * N < n_bytes; buff++) {
      cm_.recv_from(src, buff);
      in_flight++;
    }

    // Main loop
    while (in_flight > 0) {
      // Poll the hardware for completions.
      //
      // If a recv was completed copy it to the output and if we have more
      // data to fetch post another recv.
      ibv_wc wc[WC_NUM];
      int n = cm_.poll(WC_NUM, wc);
      for (int i = 0; i < n; i++) {
        // The buffer index is read from bits 8-15 of the work request id.
        int buff = (wc[i].wr_id >> 8) & 0xff;
        in_flight--;

        // Copy the chunk out; the final chunk may be shorter than N.
        std::copy(
            cm_.buffer(src, buff).begin<char>(),
            cm_.buffer(src, buff).begin<char>() +
                std::min(n_bytes - write_offset, N),
            data + write_offset);
        write_offset += N;

        // PIPELINE - 1 receives may still be posted for the chunks that
        // follow `write_offset`; only post another one if a chunk exists
        // beyond those.
        if (write_offset + (PIPELINE - 1) * N < n_bytes) {
          cm_.recv_from(src, buff);
          in_flight++;
        }
      }
    }
  });
}
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
throw std::runtime_error("[ibv] Group split not supported.");