mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix ring of 2 and allow scalars in API (#1906)
This commit is contained in:
parent
7d042f17fe
commit
6bf00ef631
@ -3,6 +3,7 @@
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
|
||||
@ -22,6 +23,10 @@
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
||||
#ifndef SOL_TCP
|
||||
#define SOL_TCP IPPROTO_TCP
|
||||
#endif
|
||||
|
||||
#define SWITCH_TYPE(x, ...) \
|
||||
switch ((x).dtype()) { \
|
||||
case bool_: { \
|
||||
@ -226,7 +231,7 @@ class SocketThread {
|
||||
if (!recvs_.empty()) {
|
||||
auto& task = recvs_.front();
|
||||
ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
|
||||
if (r >= 0) {
|
||||
if (r > 0) {
|
||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||
task.size -= r;
|
||||
delete_recv = task.size == 0;
|
||||
@ -239,7 +244,7 @@ class SocketThread {
|
||||
if (!sends_.empty()) {
|
||||
auto& task = sends_.front();
|
||||
ssize_t r = ::send(fd_, task.buffer, task.size, 0);
|
||||
if (r >= 0) {
|
||||
if (r > 0) {
|
||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||
task.size -= r;
|
||||
delete_send = task.size == 0;
|
||||
@ -560,6 +565,13 @@ class RingGroup : public GroupImpl {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Configure all sockets to use TCP no delay.
|
||||
int one = 1;
|
||||
for (int i = 0; i < sockets_right_.size(); i++) {
|
||||
setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||
setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||
}
|
||||
|
||||
// Start the all reduce threads. One all reduce per direction per ring.
|
||||
pool_.resize(sockets_right_.size() + sockets_left_.size());
|
||||
|
||||
@ -624,12 +636,15 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void recv(array& out, int src) override {
|
||||
// NOTE: We'll check the sockets in the opposite order of send so that
|
||||
// they work even with 2 nodes where left and right are the same
|
||||
// neighbor.
|
||||
int right = (rank_ + 1) % size_;
|
||||
int left = (rank_ + size_ - 1) % size_;
|
||||
if (src == right) {
|
||||
recv(sockets_right_, out.data<char>(), out.nbytes());
|
||||
} else if (src == left) {
|
||||
if (src == left) {
|
||||
recv(sockets_left_, out.data<char>(), out.nbytes());
|
||||
} else if (src == right) {
|
||||
recv(sockets_right_, out.data<char>(), out.nbytes());
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Recv only supported from direct neighbors "
|
||||
@ -801,9 +816,12 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void send(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||
size_t segment_size = ceildiv(data_size, sockets.size());
|
||||
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
||||
std::vector<std::future<void>> sends;
|
||||
for (int i = 0; i < sockets.size(); i++) {
|
||||
if (i * segment_size >= data_size) {
|
||||
break;
|
||||
}
|
||||
sends.emplace_back(comm_.send(
|
||||
sockets[i],
|
||||
data + i * segment_size,
|
||||
@ -815,9 +833,12 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||
size_t segment_size = ceildiv(data_size, sockets.size());
|
||||
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
||||
std::vector<std::future<void>> recvs;
|
||||
for (int i = 0; i < sockets.size(); i++) {
|
||||
if (i * segment_size >= data_size) {
|
||||
break;
|
||||
}
|
||||
recvs.emplace_back(comm_.recv(
|
||||
sockets[i],
|
||||
data + i * segment_size,
|
||||
|
@ -10,6 +10,8 @@
|
||||
#include "mlx/distributed/distributed.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
|
||||
#include "python/src/utils.h"
|
||||
|
||||
namespace mx = mlx::core;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
@ -86,7 +88,11 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_sum",
|
||||
&mx::distributed::all_sum,
|
||||
[](const ScalarOrArray& x,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::all_sum(to_array(x), group, s);
|
||||
},
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
@ -112,7 +118,11 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"all_gather",
|
||||
&mx::distributed::all_gather,
|
||||
[](const ScalarOrArray& x,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::all_gather(to_array(x), group, s);
|
||||
},
|
||||
"x"_a,
|
||||
nb::kw_only(),
|
||||
"group"_a = nb::none(),
|
||||
@ -139,7 +149,12 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"send",
|
||||
&mx::distributed::send,
|
||||
[](const ScalarOrArray& x,
|
||||
int dst,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::send(to_array(x), dst, group, s);
|
||||
},
|
||||
"x"_a,
|
||||
"dst"_a,
|
||||
nb::kw_only(),
|
||||
@ -195,7 +210,12 @@ void init_distributed(nb::module_& parent_module) {
|
||||
|
||||
m.def(
|
||||
"recv_like",
|
||||
&mx::distributed::recv_like,
|
||||
[](const ScalarOrArray& x,
|
||||
int src,
|
||||
std::optional<mx::distributed::Group> group,
|
||||
mx::StreamOrDevice s) {
|
||||
return mx::distributed::recv_like(to_array(x), src, group, s);
|
||||
},
|
||||
"x"_a,
|
||||
"src"_a,
|
||||
nb::kw_only(),
|
||||
|
Loading…
Reference in New Issue
Block a user