diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index 1f1c1b0b6..5bf08200e 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -199,6 +199,7 @@ class SocketThread { } void worker() { + int error_count = 0; bool delete_recv = false; bool delete_send = false; while (true) { @@ -235,10 +236,11 @@ class SocketThread { task.buffer = static_cast(task.buffer) + r; task.size -= r; delete_recv = task.size == 0; + error_count = 0; } else if (errno != EAGAIN) { + error_count++; log_info( true, "Receiving from socket", fd_, "failed with errno", errno); - return; } } if (!sends_.empty()) { @@ -248,11 +250,17 @@ class SocketThread { task.buffer = static_cast(task.buffer) + r; task.size -= r; delete_send = task.size == 0; + error_count = 0; } else if (errno != EAGAIN) { + error_count++; log_info(true, "Sending to socket", fd_, "failed with errno", errno); - return; } } + + if (error_count >= 10) { + log_info(true, "Too many send/recv errors. Aborting..."); + return; + } } } diff --git a/mlx/version.h b/mlx/version.h index 35b026149..f244dcb16 100644 --- a/mlx/version.h +++ b/mlx/version.h @@ -3,7 +3,7 @@ #pragma once #define MLX_VERSION_MAJOR 0 -#define MLX_VERSION_MINOR 24 +#define MLX_VERSION_MINOR 23 #define MLX_VERSION_PATCH 2 #define MLX_VERSION_NUMERIC \ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 1a749beed..5d6bc4383 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -112,7 +112,12 @@ def extract_rings(hosts, index): break if not ring: break - rings.append(normalize(concretize(ring, used_ports))) + try: + rings.append(normalize(concretize(ring, used_ports))) + except RuntimeError: + if len(rings) > 0: + return rings + raise return rings diff --git a/setup.py b/setup.py index d4b5e15dd..72bc2dba3 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ import os import platform import re import subprocess -import sys from pathlib import Path from subprocess import run @@ -173,7 +172,7 @@ if __name__ == "__main__": setup( name="mlx", - version=get_version("0.23.1"), + version=get_version("0.23.2"), author="MLX Contributors", author_email="mlx@group.apple.com", description="A framework for machine learning on Apple silicon.",