mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Fix ring of 2 and allow scalars in API (#1906)
This commit is contained in:
parent
7d042f17fe
commit
6bf00ef631
@ -3,6 +3,7 @@
|
|||||||
#include <arpa/inet.h>
|
#include <arpa/inet.h>
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <netdb.h>
|
#include <netdb.h>
|
||||||
|
#include <netinet/tcp.h>
|
||||||
#include <sys/socket.h>
|
#include <sys/socket.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
|
||||||
@ -22,6 +23,10 @@
|
|||||||
#include "mlx/distributed/distributed_impl.h"
|
#include "mlx/distributed/distributed_impl.h"
|
||||||
#include "mlx/threadpool.h"
|
#include "mlx/threadpool.h"
|
||||||
|
|
||||||
|
#ifndef SOL_TCP
|
||||||
|
#define SOL_TCP IPPROTO_TCP
|
||||||
|
#endif
|
||||||
|
|
||||||
#define SWITCH_TYPE(x, ...) \
|
#define SWITCH_TYPE(x, ...) \
|
||||||
switch ((x).dtype()) { \
|
switch ((x).dtype()) { \
|
||||||
case bool_: { \
|
case bool_: { \
|
||||||
@ -226,7 +231,7 @@ class SocketThread {
|
|||||||
if (!recvs_.empty()) {
|
if (!recvs_.empty()) {
|
||||||
auto& task = recvs_.front();
|
auto& task = recvs_.front();
|
||||||
ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
|
ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
|
||||||
if (r >= 0) {
|
if (r > 0) {
|
||||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||||
task.size -= r;
|
task.size -= r;
|
||||||
delete_recv = task.size == 0;
|
delete_recv = task.size == 0;
|
||||||
@ -239,7 +244,7 @@ class SocketThread {
|
|||||||
if (!sends_.empty()) {
|
if (!sends_.empty()) {
|
||||||
auto& task = sends_.front();
|
auto& task = sends_.front();
|
||||||
ssize_t r = ::send(fd_, task.buffer, task.size, 0);
|
ssize_t r = ::send(fd_, task.buffer, task.size, 0);
|
||||||
if (r >= 0) {
|
if (r > 0) {
|
||||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||||
task.size -= r;
|
task.size -= r;
|
||||||
delete_send = task.size == 0;
|
delete_send = task.size == 0;
|
||||||
@ -560,6 +565,13 @@ class RingGroup : public GroupImpl {
|
|||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure all sockets to use TCP no delay.
|
||||||
|
int one = 1;
|
||||||
|
for (int i = 0; i < sockets_right_.size(); i++) {
|
||||||
|
setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||||
|
setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||||
|
}
|
||||||
|
|
||||||
// Start the all reduce threads. One all reduce per direction per ring.
|
// Start the all reduce threads. One all reduce per direction per ring.
|
||||||
pool_.resize(sockets_right_.size() + sockets_left_.size());
|
pool_.resize(sockets_right_.size() + sockets_left_.size());
|
||||||
|
|
||||||
@ -624,12 +636,15 @@ class RingGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void recv(array& out, int src) override {
|
void recv(array& out, int src) override {
|
||||||
|
// NOTE: We'll check the sockets in the opposite order of send so that
|
||||||
|
// they work even with 2 nodes where left and right are the same
|
||||||
|
// neighbor.
|
||||||
int right = (rank_ + 1) % size_;
|
int right = (rank_ + 1) % size_;
|
||||||
int left = (rank_ + size_ - 1) % size_;
|
int left = (rank_ + size_ - 1) % size_;
|
||||||
if (src == right) {
|
if (src == left) {
|
||||||
recv(sockets_right_, out.data<char>(), out.nbytes());
|
|
||||||
} else if (src == left) {
|
|
||||||
recv(sockets_left_, out.data<char>(), out.nbytes());
|
recv(sockets_left_, out.data<char>(), out.nbytes());
|
||||||
|
} else if (src == right) {
|
||||||
|
recv(sockets_right_, out.data<char>(), out.nbytes());
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[ring] Recv only supported from direct neighbors "
|
msg << "[ring] Recv only supported from direct neighbors "
|
||||||
@ -801,9 +816,12 @@ class RingGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void send(const std::vector<int>& sockets, char* data, size_t data_size) {
|
void send(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||||
size_t segment_size = ceildiv(data_size, sockets.size());
|
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
||||||
std::vector<std::future<void>> sends;
|
std::vector<std::future<void>> sends;
|
||||||
for (int i = 0; i < sockets.size(); i++) {
|
for (int i = 0; i < sockets.size(); i++) {
|
||||||
|
if (i * segment_size >= data_size) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
sends.emplace_back(comm_.send(
|
sends.emplace_back(comm_.send(
|
||||||
sockets[i],
|
sockets[i],
|
||||||
data + i * segment_size,
|
data + i * segment_size,
|
||||||
@ -815,9 +833,12 @@ class RingGroup : public GroupImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
|
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||||
size_t segment_size = ceildiv(data_size, sockets.size());
|
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
||||||
std::vector<std::future<void>> recvs;
|
std::vector<std::future<void>> recvs;
|
||||||
for (int i = 0; i < sockets.size(); i++) {
|
for (int i = 0; i < sockets.size(); i++) {
|
||||||
|
if (i * segment_size >= data_size) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
recvs.emplace_back(comm_.recv(
|
recvs.emplace_back(comm_.recv(
|
||||||
sockets[i],
|
sockets[i],
|
||||||
data + i * segment_size,
|
data + i * segment_size,
|
||||||
|
@ -10,6 +10,8 @@
|
|||||||
#include "mlx/distributed/distributed.h"
|
#include "mlx/distributed/distributed.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
|
|
||||||
|
#include "python/src/utils.h"
|
||||||
|
|
||||||
namespace mx = mlx::core;
|
namespace mx = mlx::core;
|
||||||
namespace nb = nanobind;
|
namespace nb = nanobind;
|
||||||
using namespace nb::literals;
|
using namespace nb::literals;
|
||||||
@ -86,7 +88,11 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"all_sum",
|
"all_sum",
|
||||||
&mx::distributed::all_sum,
|
[](const ScalarOrArray& x,
|
||||||
|
std::optional<mx::distributed::Group> group,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
return mx::distributed::all_sum(to_array(x), group, s);
|
||||||
|
},
|
||||||
"x"_a,
|
"x"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"group"_a = nb::none(),
|
"group"_a = nb::none(),
|
||||||
@ -112,7 +118,11 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"all_gather",
|
"all_gather",
|
||||||
&mx::distributed::all_gather,
|
[](const ScalarOrArray& x,
|
||||||
|
std::optional<mx::distributed::Group> group,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
return mx::distributed::all_gather(to_array(x), group, s);
|
||||||
|
},
|
||||||
"x"_a,
|
"x"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"group"_a = nb::none(),
|
"group"_a = nb::none(),
|
||||||
@ -139,7 +149,12 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"send",
|
"send",
|
||||||
&mx::distributed::send,
|
[](const ScalarOrArray& x,
|
||||||
|
int dst,
|
||||||
|
std::optional<mx::distributed::Group> group,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
return mx::distributed::send(to_array(x), dst, group, s);
|
||||||
|
},
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"dst"_a,
|
"dst"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
@ -195,7 +210,12 @@ void init_distributed(nb::module_& parent_module) {
|
|||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"recv_like",
|
"recv_like",
|
||||||
&mx::distributed::recv_like,
|
[](const ScalarOrArray& x,
|
||||||
|
int src,
|
||||||
|
std::optional<mx::distributed::Group> group,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
return mx::distributed::recv_like(to_array(x), src, group, s);
|
||||||
|
},
|
||||||
"x"_a,
|
"x"_a,
|
||||||
"src"_a,
|
"src"_a,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
|
Loading…
Reference in New Issue
Block a user