Fix ring of 2 and allow scalars in API (#1906)

This commit is contained in:
Angelos Katharopoulos 2025-02-25 17:03:01 -08:00 committed by GitHub
parent 7d042f17fe
commit 6bf00ef631
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 11 deletions

View File

@@ -3,6 +3,7 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <fcntl.h> #include <fcntl.h>
#include <netdb.h> #include <netdb.h>
#include <netinet/tcp.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <unistd.h> #include <unistd.h>
@@ -22,6 +23,10 @@
#include "mlx/distributed/distributed_impl.h" #include "mlx/distributed/distributed_impl.h"
#include "mlx/threadpool.h" #include "mlx/threadpool.h"
#ifndef SOL_TCP
#define SOL_TCP IPPROTO_TCP
#endif
#define SWITCH_TYPE(x, ...) \ #define SWITCH_TYPE(x, ...) \
switch ((x).dtype()) { \ switch ((x).dtype()) { \
case bool_: { \ case bool_: { \
@@ -226,7 +231,7 @@ class SocketThread {
if (!recvs_.empty()) { if (!recvs_.empty()) {
auto& task = recvs_.front(); auto& task = recvs_.front();
ssize_t r = ::recv(fd_, task.buffer, task.size, 0); ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
if (r >= 0) { if (r > 0) {
task.buffer = static_cast<char*>(task.buffer) + r; task.buffer = static_cast<char*>(task.buffer) + r;
task.size -= r; task.size -= r;
delete_recv = task.size == 0; delete_recv = task.size == 0;
@@ -239,7 +244,7 @@ class SocketThread {
if (!sends_.empty()) { if (!sends_.empty()) {
auto& task = sends_.front(); auto& task = sends_.front();
ssize_t r = ::send(fd_, task.buffer, task.size, 0); ssize_t r = ::send(fd_, task.buffer, task.size, 0);
if (r >= 0) { if (r > 0) {
task.buffer = static_cast<char*>(task.buffer) + r; task.buffer = static_cast<char*>(task.buffer) + r;
task.size -= r; task.size -= r;
delete_send = task.size == 0; delete_send = task.size == 0;
@@ -560,6 +565,13 @@ class RingGroup : public GroupImpl {
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// Configure all sockets to use TCP no delay.
int one = 1;
for (int i = 0; i < sockets_right_.size(); i++) {
setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
}
// Start the all reduce threads. One all reduce per direction per ring. // Start the all reduce threads. One all reduce per direction per ring.
pool_.resize(sockets_right_.size() + sockets_left_.size()); pool_.resize(sockets_right_.size() + sockets_left_.size());
@@ -624,12 +636,15 @@ class RingGroup : public GroupImpl {
} }
void recv(array& out, int src) override { void recv(array& out, int src) override {
// NOTE: We 'll check the sockets with the opposite order of send so that
// they work even with 2 nodes where left and right is the same
// neighbor.
int right = (rank_ + 1) % size_; int right = (rank_ + 1) % size_;
int left = (rank_ + size_ - 1) % size_; int left = (rank_ + size_ - 1) % size_;
if (src == right) { if (src == left) {
recv(sockets_right_, out.data<char>(), out.nbytes());
} else if (src == left) {
recv(sockets_left_, out.data<char>(), out.nbytes()); recv(sockets_left_, out.data<char>(), out.nbytes());
} else if (src == right) {
recv(sockets_right_, out.data<char>(), out.nbytes());
} else { } else {
std::ostringstream msg; std::ostringstream msg;
msg << "[ring] Recv only supported from direct neighbors " msg << "[ring] Recv only supported from direct neighbors "
@@ -801,9 +816,12 @@ class RingGroup : public GroupImpl {
} }
void send(const std::vector<int>& sockets, char* data, size_t data_size) { void send(const std::vector<int>& sockets, char* data, size_t data_size) {
size_t segment_size = ceildiv(data_size, sockets.size()); size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
std::vector<std::future<void>> sends; std::vector<std::future<void>> sends;
for (int i = 0; i < sockets.size(); i++) { for (int i = 0; i < sockets.size(); i++) {
if (i * segment_size >= data_size) {
break;
}
sends.emplace_back(comm_.send( sends.emplace_back(comm_.send(
sockets[i], sockets[i],
data + i * segment_size, data + i * segment_size,
@@ -815,9 +833,12 @@ class RingGroup : public GroupImpl {
} }
void recv(const std::vector<int>& sockets, char* data, size_t data_size) { void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
size_t segment_size = ceildiv(data_size, sockets.size()); size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
std::vector<std::future<void>> recvs; std::vector<std::future<void>> recvs;
for (int i = 0; i < sockets.size(); i++) { for (int i = 0; i < sockets.size(); i++) {
if (i * segment_size >= data_size) {
break;
}
recvs.emplace_back(comm_.recv( recvs.emplace_back(comm_.recv(
sockets[i], sockets[i],
data + i * segment_size, data + i * segment_size,

View File

@@ -10,6 +10,8 @@
#include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed.h"
#include "mlx/distributed/ops.h" #include "mlx/distributed/ops.h"
#include "python/src/utils.h"
namespace mx = mlx::core; namespace mx = mlx::core;
namespace nb = nanobind; namespace nb = nanobind;
using namespace nb::literals; using namespace nb::literals;
@@ -86,7 +88,11 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"all_sum", "all_sum",
&mx::distributed::all_sum, [](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_sum(to_array(x), group, s);
},
"x"_a, "x"_a,
nb::kw_only(), nb::kw_only(),
"group"_a = nb::none(), "group"_a = nb::none(),
@@ -112,7 +118,11 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"all_gather", "all_gather",
&mx::distributed::all_gather, [](const ScalarOrArray& x,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::all_gather(to_array(x), group, s);
},
"x"_a, "x"_a,
nb::kw_only(), nb::kw_only(),
"group"_a = nb::none(), "group"_a = nb::none(),
@@ -139,7 +149,12 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"send", "send",
&mx::distributed::send, [](const ScalarOrArray& x,
int dst,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::send(to_array(x), dst, group, s);
},
"x"_a, "x"_a,
"dst"_a, "dst"_a,
nb::kw_only(), nb::kw_only(),
@@ -195,7 +210,12 @@ void init_distributed(nb::module_& parent_module) {
m.def( m.def(
"recv_like", "recv_like",
&mx::distributed::recv_like, [](const ScalarOrArray& x,
int src,
std::optional<mx::distributed::Group> group,
mx::StreamOrDevice s) {
return mx::distributed::recv_like(to_array(x), src, group, s);
},
"x"_a, "x"_a,
"src"_a, "src"_a,
nb::kw_only(), nb::kw_only(),