From 0792ff02ff863106cf22076862909abc1f611c31 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 5 Mar 2025 13:16:19 -0800 Subject: [PATCH] Only fail when 10 consecutive socket errors occur (#1928) --- mlx/distributed/ring/ring.cpp | 12 ++++++++++-- python/mlx/distributed_run.py | 7 ++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index 1f1c1b0b6..5bf08200e 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -199,6 +199,7 @@ class SocketThread { } void worker() { + int error_count = 0; bool delete_recv = false; bool delete_send = false; while (true) { @@ -235,10 +236,11 @@ class SocketThread { task.buffer = static_cast<char*>(task.buffer) + r; task.size -= r; delete_recv = task.size == 0; + error_count = 0; } else if (errno != EAGAIN) { + error_count++; log_info( true, "Receiving from socket", fd_, "failed with errno", errno); - return; } } if (!sends_.empty()) { @@ -248,11 +250,17 @@ class SocketThread { task.buffer = static_cast<char*>(task.buffer) + r; task.size -= r; delete_send = task.size == 0; + error_count = 0; } else if (errno != EAGAIN) { + error_count++; log_info(true, "Sending to socket", fd_, "failed with errno", errno); - return; } } + + if (error_count >= 10) { + log_info(true, "Too many send/recv errors. Aborting..."); + return; + } } } diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 1a749beed..5d6bc4383 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -112,7 +112,12 @@ def extract_rings(hosts, index): break if not ring: break - rings.append(normalize(concretize(ring, used_ports))) + try: + rings.append(normalize(concretize(ring, used_ports))) + except RuntimeError: + if len(rings) > 0: + return rings + raise return rings