diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 9c946005b..404ecc349 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -270,9 +270,11 @@ def launch_ring(parser, hosts, args, command): # Repeat the stdout and stderr to the local machine to_read = [p.stdout.fileno(), p.stderr.fileno()] - to_write = [p.stdin.fileno()] + to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()] pidfile = "" stdin_buffer = b"" + stdout_buffer = b"" + stderr_buffer = b"" while p.poll() is None: try: stdin_buffer += input_queue.get_nowait() @@ -280,8 +282,6 @@ def launch_ring(parser, hosts, args, command): pass rlist, wlist, _ = select(to_read, to_write, [], 1.0) for fd in rlist: - is_stdout = fd == p.stdout.fileno() - outfile = sys.stdout if is_stdout else sys.stderr msg = os.read(fd, 8192).decode(errors="ignore") # Fetch the PID file first if we haven't already @@ -289,12 +289,21 @@ def launch_ring(parser, hosts, args, command): pidfile, *msg = msg.split("\n", maxsplit=1) msg = msg[0] if msg else "" - outfile.write(msg) - outfile.flush() + is_stdout = fd == p.stdout.fileno() + if is_stdout: + stdout_buffer += msg.encode() + else: + stderr_buffer += msg.encode() for fd in wlist: - if len(stdin_buffer) > 0: + if fd == p.stdin.fileno() and len(stdin_buffer) > 0: n = os.write(fd, stdin_buffer) stdin_buffer = stdin_buffer[n:] + elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0: + n = os.write(fd, stdout_buffer) + stdout_buffer = stdout_buffer[n:] + elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0: + n = os.write(fd, stderr_buffer) + stderr_buffer = stderr_buffer[n:] if stop: p.terminate() break