diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp
index 2878abf9c..1f1c1b0b6 100644
--- a/mlx/distributed/ring/ring.cpp
+++ b/mlx/distributed/ring/ring.cpp
@@ -3,6 +3,7 @@
 #include <arpa/inet.h>
 #include <netdb.h>
 #include <netinet/in.h>
+#include <netinet/tcp.h>
 #include <sys/socket.h>
 #include <unistd.h>
 
@@ -22,6 +23,10 @@
 #include "mlx/distributed/distributed_impl.h"
 #include "mlx/threadpool.h"
 
+#ifndef SOL_TCP
+#define SOL_TCP IPPROTO_TCP
+#endif
+
 #define SWITCH_TYPE(x, ...) \
   switch ((x).dtype()) {    \
     case bool_: {           \
@@ -226,7 +231,7 @@ class SocketThread {
     if (!recvs_.empty()) {
       auto& task = recvs_.front();
       ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
-      if (r >= 0) {
+      if (r > 0) {
         task.buffer = static_cast<char*>(task.buffer) + r;
         task.size -= r;
         delete_recv = task.size == 0;
@@ -239,7 +244,7 @@ class SocketThread {
     if (!sends_.empty()) {
       auto& task = sends_.front();
       ssize_t r = ::send(fd_, task.buffer, task.size, 0);
-      if (r >= 0) {
+      if (r > 0) {
         task.buffer = static_cast<char*>(task.buffer) + r;
         task.size -= r;
         delete_send = task.size == 0;
@@ -560,6 +565,13 @@ class RingGroup : public GroupImpl {
       throw std::invalid_argument(msg.str());
     }
 
+    // Configure all sockets to use TCP no delay.
+    int one = 1;
+    for (int i = 0; i < sockets_right_.size(); i++) {
+      setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
+      setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
+    }
+
     // Start the all reduce threads. One all reduce per direction per ring.
     pool_.resize(sockets_right_.size() + sockets_left_.size());
 
@@ -624,12 +636,15 @@ class RingGroup : public GroupImpl {
   }
 
   void recv(array& out, int src) override {
+    // NOTE: We 'll check the sockets with the opposite order of send so that
+    //       they work even with 2 nodes where left and right is the same
+    //       neighbor.
     int right = (rank_ + 1) % size_;
     int left = (rank_ + size_ - 1) % size_;
-    if (src == right) {
-      recv(sockets_right_, out.data<char>(), out.nbytes());
-    } else if (src == left) {
+    if (src == left) {
       recv(sockets_left_, out.data<char>(), out.nbytes());
+    } else if (src == right) {
+      recv(sockets_right_, out.data<char>(), out.nbytes());
     } else {
       std::ostringstream msg;
       msg << "[ring] Recv only supported from direct neighbors "
@@ -801,9 +816,12 @@ class RingGroup : public GroupImpl {
   }
 
   void send(const std::vector<int>& sockets, char* data, size_t data_size) {
-    size_t segment_size = ceildiv(data_size, sockets.size());
+    size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
     std::vector<std::future<void>> sends;
     for (int i = 0; i < sockets.size(); i++) {
+      if (i * segment_size >= data_size) {
+        break;
+      }
       sends.emplace_back(comm_.send(
           sockets[i],
           data + i * segment_size,
@@ -815,9 +833,12 @@ class RingGroup : public GroupImpl {
   }
 
   void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
-    size_t segment_size = ceildiv(data_size, sockets.size());
+    size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
     std::vector<std::future<void>> recvs;
     for (int i = 0; i < sockets.size(); i++) {
+      if (i * segment_size >= data_size) {
+        break;
+      }
       recvs.emplace_back(comm_.recv(
           sockets[i],
           data + i * segment_size,
diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp
index f3df1904d..b4b6550af 100644
--- a/python/src/distributed.cpp
+++ b/python/src/distributed.cpp
@@ -10,6 +10,8 @@
 #include "mlx/distributed/distributed.h"
 #include "mlx/distributed/ops.h"
 
+#include "python/src/utils.h"
+
 namespace mx = mlx::core;
 namespace nb = nanobind;
 using namespace nb::literals;
@@ -86,7 +88,11 @@ void init_distributed(nb::module_& parent_module) {
   m.def(
       "all_sum",
-      &mx::distributed::all_sum,
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::all_sum(to_array(x), group, s);
+      },
       "x"_a,
       nb::kw_only(),
       "group"_a = nb::none(),
@@
-112,7 +118,11 @@ void init_distributed(nb::module_& parent_module) {
   m.def(
       "all_gather",
-      &mx::distributed::all_gather,
+      [](const ScalarOrArray& x,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::all_gather(to_array(x), group, s);
+      },
       "x"_a,
       nb::kw_only(),
       "group"_a = nb::none(),
@@ -139,7 +149,12 @@ void init_distributed(nb::module_& parent_module) {
   m.def(
       "send",
-      &mx::distributed::send,
+      [](const ScalarOrArray& x,
+         int dst,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::send(to_array(x), dst, group, s);
+      },
       "x"_a,
       "dst"_a,
       nb::kw_only(),
@@ -195,7 +210,12 @@ void init_distributed(nb::module_& parent_module) {
   m.def(
       "recv_like",
-      &mx::distributed::recv_like,
+      [](const ScalarOrArray& x,
+         int src,
+         std::optional<mx::distributed::Group> group,
+         mx::StreamOrDevice s) {
+        return mx::distributed::recv_like(to_array(x), src, group, s);
+      },
       "x"_a,
       "src"_a,
       nb::kw_only(),