diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index c1275737a..9679d9ff8 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -336,7 +336,7 @@ std::vector accept_connections( for (auto& address : addresses) { detail::TCPSocket socket(RING_TAG); socket.listen(RING_TAG, address); - sockets.push_back(socket.accept(RING_TAG)); + sockets.push_back(socket.accept(RING_TAG).detach()); } return sockets; @@ -354,21 +354,22 @@ std::vector make_connections( for (auto& address : addresses) { sockets.push_back(detail::TCPSocket::connect( - RING_TAG, - address, - CONN_ATTEMPTS, - CONN_WAIT, - [verbose](int attempt, int wait) { - log_info( - verbose, - "Attempt", - attempt, - "waiting", - wait, - "ms (error:", - errno, - ")"); - })); + RING_TAG, + address, + CONN_ATTEMPTS, + CONN_WAIT, + [verbose](int attempt, int wait) { + log_info( + verbose, + "Attempt", + attempt, + "waiting", + wait, + "ms (error:", + errno, + ")"); + }) + .detach()); } return sockets; diff --git a/mlx/distributed/utils.cpp b/mlx/distributed/utils.cpp index a91fa2488..1e7f50d2b 100644 --- a/mlx/distributed/utils.cpp +++ b/mlx/distributed/utils.cpp @@ -80,6 +80,12 @@ TCPSocket::~TCPSocket() { } } +int TCPSocket::detach() { + int s = sock_; + sock_ = -1; + return s; +} + void TCPSocket::listen(const char* tag, const address_t& addr) { int success; diff --git a/mlx/distributed/utils.h b/mlx/distributed/utils.h index ce0220931..cf229c5fc 100644 --- a/mlx/distributed/utils.h +++ b/mlx/distributed/utils.h @@ -43,6 +43,8 @@ class TCPSocket { void send(const char* tag, const void* data, size_t len); void recv(const char* tag, void* data, size_t len); + int detach(); + operator int() const { return sock_; }