mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-17 17:28:10 +08:00
Fix ring of 2 and allow scalars in API (#1906)
This commit is contained in:

committed by
GitHub

parent
7d042f17fe
commit
6bf00ef631
@@ -3,6 +3,7 @@
|
||||
#include <arpa/inet.h>
|
||||
#include <fcntl.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/socket.h>
|
||||
#include <unistd.h>
|
||||
|
||||
@@ -22,6 +23,10 @@
|
||||
#include "mlx/distributed/distributed_impl.h"
|
||||
#include "mlx/threadpool.h"
|
||||
|
||||
#ifndef SOL_TCP
|
||||
#define SOL_TCP IPPROTO_TCP
|
||||
#endif
|
||||
|
||||
#define SWITCH_TYPE(x, ...) \
|
||||
switch ((x).dtype()) { \
|
||||
case bool_: { \
|
||||
@@ -226,7 +231,7 @@ class SocketThread {
|
||||
if (!recvs_.empty()) {
|
||||
auto& task = recvs_.front();
|
||||
ssize_t r = ::recv(fd_, task.buffer, task.size, 0);
|
||||
if (r >= 0) {
|
||||
if (r > 0) {
|
||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||
task.size -= r;
|
||||
delete_recv = task.size == 0;
|
||||
@@ -239,7 +244,7 @@ class SocketThread {
|
||||
if (!sends_.empty()) {
|
||||
auto& task = sends_.front();
|
||||
ssize_t r = ::send(fd_, task.buffer, task.size, 0);
|
||||
if (r >= 0) {
|
||||
if (r > 0) {
|
||||
task.buffer = static_cast<char*>(task.buffer) + r;
|
||||
task.size -= r;
|
||||
delete_send = task.size == 0;
|
||||
@@ -560,6 +565,13 @@ class RingGroup : public GroupImpl {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Configure all sockets to use TCP no delay.
|
||||
int one = 1;
|
||||
for (int i = 0; i < sockets_right_.size(); i++) {
|
||||
setsockopt(sockets_right_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||
setsockopt(sockets_left_[i], SOL_TCP, TCP_NODELAY, &one, sizeof(one));
|
||||
}
|
||||
|
||||
// Start the all reduce threads. One all reduce per direction per ring.
|
||||
pool_.resize(sockets_right_.size() + sockets_left_.size());
|
||||
|
||||
@@ -624,12 +636,15 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void recv(array& out, int src) override {
|
||||
// NOTE: We 'll check the sockets with the opposite order of send so that
|
||||
// they work even with 2 nodes where left and right is the same
|
||||
// neighbor.
|
||||
int right = (rank_ + 1) % size_;
|
||||
int left = (rank_ + size_ - 1) % size_;
|
||||
if (src == right) {
|
||||
recv(sockets_right_, out.data<char>(), out.nbytes());
|
||||
} else if (src == left) {
|
||||
if (src == left) {
|
||||
recv(sockets_left_, out.data<char>(), out.nbytes());
|
||||
} else if (src == right) {
|
||||
recv(sockets_right_, out.data<char>(), out.nbytes());
|
||||
} else {
|
||||
std::ostringstream msg;
|
||||
msg << "[ring] Recv only supported from direct neighbors "
|
||||
@@ -801,9 +816,12 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void send(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||
size_t segment_size = ceildiv(data_size, sockets.size());
|
||||
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
||||
std::vector<std::future<void>> sends;
|
||||
for (int i = 0; i < sockets.size(); i++) {
|
||||
if (i * segment_size >= data_size) {
|
||||
break;
|
||||
}
|
||||
sends.emplace_back(comm_.send(
|
||||
sockets[i],
|
||||
data + i * segment_size,
|
||||
@@ -815,9 +833,12 @@ class RingGroup : public GroupImpl {
|
||||
}
|
||||
|
||||
void recv(const std::vector<int>& sockets, char* data, size_t data_size) {
|
||||
size_t segment_size = ceildiv(data_size, sockets.size());
|
||||
size_t segment_size = std::max(1024UL, ceildiv(data_size, sockets.size()));
|
||||
std::vector<std::future<void>> recvs;
|
||||
for (int i = 0; i < sockets.size(); i++) {
|
||||
if (i * segment_size >= data_size) {
|
||||
break;
|
||||
}
|
||||
recvs.emplace_back(comm_.recv(
|
||||
sockets[i],
|
||||
data + i * segment_size,
|
||||
|
Reference in New Issue
Block a user