Fix ring of 2 and allow scalars in API (#1906)

This commit is contained in:
Angelos Katharopoulos 2025-02-25 17:03:01 -08:00 committed by GitHub
parent 7d042f17fe
commit 6bf00ef631
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 11 deletions

View File

@ -3,6 +3,7 @@
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <unistd.h>
@ -22,6 +23,10 @@
#include "mlx/distributed/distributed_impl.h"
#include "mlx/threadpool.h"
#ifndef SOL_TCP
#define SOL_TCP IPPROTO_TCP
#endif
#define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \
case bool_: { \
@ -226,7 +231,7 @@ class SocketThread {
if (!recvs_.empty()) {
auto& task = recvs_.front();
ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
if (r >= 0) {
if (r > 0) {
task.buffer = static_cast<char*>(task.buffer) + r;
task.size -= r;
delete_recv = task.size == 0;
@ -239,7 +244,7 @@ class SocketThread {
if (!sends_.empty()) {
auto& task = sends_.front();
ssize_t r = ::send(fd_, task.buffer, task.size, 0);
if (r >= 0) {
if (r > 0) {
task.buffer = static_cast<char*>(task.buffer) + r;
task.size -= r;
delete_send = task.size == 0;
@ -560,6 +565,13 @@ class RingGroup : public GroupImpl {
throw std::invalid_argument(msg.str());
}
// Configure all sockets to use TCP no delay.
int one = 1;
for (int i = 0; i < sockets_right_.size(); i++) {
setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
}
// Start the all reduce threads. One all reduce per direction per ring.
pool_.resize(sockets_right_.size() + sockets_left_.size());
@ -624,12 +636,15 @@ class RingGroup : public GroupImpl {
}
void recv(array& out, int src) override {
  // NOTE: We'll check the sockets with the opposite order of send so that
// they work even with 2 nodes where left and right is the same
// neighbor.
int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_;
if (src == right) {
recv(sockets_right_, out.data<char>(), out.nbytes());
} else if (src == left) {
if (src == left) {
recv(sockets_left_, out.data<char>(), out.nbytes());
} else if (src == right) {
recv(sockets_right_, out.data<char>(), out.nbytes());
} else {
std::ostringstream msg;
msg << "[ring] Recv only supported from direct neighbors "
@ -801,9 +816,12 @@ class RingGroup : public GroupImpl {
}
void send(const std::vector<int>& sockets, char* data, size_t data_size) {
size_t segment_size = ceildiv(data_size, sockets.size());
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
std::vector<std::future<void>> sends;
for (int i = 0; i < sockets.size(); i++) {
if (i * segment_size >= data_size) {
break;
}
sends.emplace_back(comm_.send(
sockets[i],
data + i * segment_size,
@ -815,9 +833,12 @@ class RingGroup : public GroupImpl {
}
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
size_t segment_size = ceildiv(data_size, sockets.size());
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
std::vector<std::future<void>> recvs;
for (int i = 0; i < sockets.size(); i++) {
if (i * segment_size >= data_size) {
break;
}
recvs.emplace_back(comm_.recv(
sockets[i],
data + i * segment_size,

View File

@ -10,6 +10,8 @@
#include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h"
#include "python/src/utils.h"
namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;
@ -86,7 +88,11 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_sum",
&mx::distributed::all_sum,
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_sum(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
@ -112,7 +118,11 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"all_gather",
&mx::distributed::all_gather,
[](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_gather(to_array(x), group, s);
},
"x"_a,
nb::kw_only(),
"group"_a = nb::none(),
@ -139,7 +149,12 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"send",
&mx::distributed::send,
[](const ScalarOrArray& x,
int dst,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::send(to_array(x), dst, group, s);
},
"x"_a,
"dst"_a,
nb::kw_only(),
@ -195,7 +210,12 @@ void init_distributed(nb::module_& parent_module) {
m.def(
"recv_like",
&mx::distributed::recv_like,
[](const ScalarOrArray& x,
int src,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::recv_like(to_array(x), src, group, s);
},
"x"_a,
"src"_a,
nb::kw_only(),