mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add send/recv
This commit is contained in:
@@ -816,9 +816,114 @@ class IBVGroup : public GroupImpl {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void send(const array& input, int dst, Stream stream) override {}
|
void send(const array& input, int dst, Stream stream) override {
|
||||||
|
auto data = input.data<char>();
|
||||||
|
int64_t n_bytes = input.nbytes();
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_input_array(input);
|
||||||
|
encoder.dispatch([data, n_bytes, dst, this]() {
|
||||||
|
constexpr int PIPELINE = 2;
|
||||||
|
constexpr int WC_NUM = PIPELINE;
|
||||||
|
constexpr int N = BUFFER_SIZE;
|
||||||
|
|
||||||
void recv(array& out, int src, Stream stream) override {}
|
int in_flight = 0;
|
||||||
|
int64_t read_offset = 0;
|
||||||
|
|
||||||
|
// Prefill the pipeline
|
||||||
|
int buff = 0;
|
||||||
|
while (read_offset < n_bytes && buff < PIPELINE) {
|
||||||
|
std::copy(
|
||||||
|
data + read_offset,
|
||||||
|
data + std::min(read_offset + N, n_bytes),
|
||||||
|
cm_.send_buffer(buff).begin<char>());
|
||||||
|
cm_.send_to(dst, buff);
|
||||||
|
|
||||||
|
buff++;
|
||||||
|
read_offset += N;
|
||||||
|
in_flight++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main loop
|
||||||
|
while (in_flight > 0) {
|
||||||
|
// Poll the hardware for completions.
|
||||||
|
//
|
||||||
|
// If a send was completed and we have more data to send then go ahead
|
||||||
|
// and send them.
|
||||||
|
ibv_wc wc[WC_NUM];
|
||||||
|
int n = cm_.poll(WC_NUM, wc);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||||
|
int rank = wc[i].wr_id & 0xff;
|
||||||
|
|
||||||
|
in_flight--;
|
||||||
|
|
||||||
|
if (read_offset < n_bytes) {
|
||||||
|
std::copy(
|
||||||
|
data + read_offset,
|
||||||
|
data + std::min(read_offset + N, n_bytes),
|
||||||
|
cm_.send_buffer(buff).begin<char>());
|
||||||
|
cm_.send_to(dst, buff);
|
||||||
|
|
||||||
|
read_offset += N;
|
||||||
|
in_flight++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void recv(array& out, int src, Stream stream) override {
|
||||||
|
auto data = out.data<char>();
|
||||||
|
int64_t n_bytes = out.nbytes();
|
||||||
|
auto& encoder = cpu::get_command_encoder(stream);
|
||||||
|
encoder.set_output_array(out);
|
||||||
|
encoder.dispatch([data, n_bytes, src, this]() {
|
||||||
|
constexpr int PIPELINE = 2;
|
||||||
|
constexpr int WC_NUM = PIPELINE;
|
||||||
|
constexpr int N = BUFFER_SIZE;
|
||||||
|
|
||||||
|
int in_flight = 0;
|
||||||
|
int64_t write_offset = 0;
|
||||||
|
|
||||||
|
// Prefill the pipeline
|
||||||
|
int buff = 0;
|
||||||
|
while (write_offset < n_bytes && buff < PIPELINE) {
|
||||||
|
cm_.recv_from(src, buff);
|
||||||
|
|
||||||
|
in_flight++;
|
||||||
|
buff++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main loop
|
||||||
|
while (in_flight > 0) {
|
||||||
|
// Poll the hardware for completions.
|
||||||
|
//
|
||||||
|
// If a recv was completed copy it to the output and if we have more
|
||||||
|
// data to fetch post another recv.
|
||||||
|
ibv_wc wc[WC_NUM];
|
||||||
|
int n = cm_.poll(WC_NUM, wc);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
int buff = (wc[i].wr_id >> 8) & 0xff;
|
||||||
|
int rank = wc[i].wr_id & 0xff;
|
||||||
|
|
||||||
|
in_flight--;
|
||||||
|
|
||||||
|
std::copy(
|
||||||
|
cm_.buffer(src, buff).begin<char>(),
|
||||||
|
cm_.buffer(src, buff).begin<char>() +
|
||||||
|
std::min(n_bytes - write_offset, static_cast<int64_t>(N)),
|
||||||
|
data + write_offset);
|
||||||
|
write_offset += N;
|
||||||
|
|
||||||
|
if (write_offset + (PIPELINE - 1) * N < n_bytes) {
|
||||||
|
cm_.recv_from(src, buff);
|
||||||
|
|
||||||
|
in_flight++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
|
||||||
throw std::runtime_error("[ibv] Group split not supported.");
|
throw std::runtime_error("[ibv] Group split not supported.");
|
||||||
|
|||||||
Reference in New Issue
Block a user