diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 2c1d27c6e..1a749beed 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -14,6 +14,8 @@ import time from collections import Counter from dataclasses import dataclass from pathlib import Path +from queue import Empty as QueueEmpty +from queue import Queue from select import select from subprocess import PIPE, Popen, run from typing import Optional @@ -185,46 +187,54 @@ def parse_hostlist(parser, hostlist, repeats): def make_monitor_script(rank, hostfile, cwd, env, command, verbose): + # Imports that are used throughout script = "" + script += "import os\n" + script += "import sys\n" + script += "import tempfile\n" + script += "from pathlib import Path\n" # Write the PID to a file so we can kill the process if needed - script += "pidfile=$(mktemp)\n" - script += "echo $$ >$pidfile\n" - script += "echo $pidfile\n" + script += "_, pidfile = tempfile.mkstemp() \n" + script += "open(pidfile, 'w').write(str(os.getpid()))\n" + script += "print(pidfile, flush=True)\n" # Change the working directory if one was requested. Otherwise attempt to - # change to change to the current one but don't fail if it wasn't possible. + # change to the current one but don't fail if it wasn't possible. d = cwd or os.getcwd() - script += f"if [ -d {shlex.quote(d)} ]; then\n" - script += f" cd {shlex.quote(d)}\n" + script += f"if Path({repr(d)}).exists():\n" + script += f" os.chdir({repr(d)})\n" if cwd is not None: - script += "else\n" - script += f" echo Failed to change directory to {shlex.quote(d)} 1>&2\n" - script += f" exit 1\n" - script += "fi\n" + script += "else:\n" + script += ( + f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n" + ) + script += f" sys.exit(1)\n" # Add the environment variables that were given to us + script += "env = dict(os.environ)\n" for e in env: key, *value = e.split("=", maxsplit=1) value = shlex.quote(value[0]) if len(value) > 0 else "" if not all(c.isalnum() or c == "_" for c in key): log_warning(f"'{e}' is an invalid environment variable so it is ignored") continue - script += f"export {key}={value}\n" + script += f"env[{repr(key)}] = {repr(value)}\n" # Add the environment variables to enable the ring distributed backend if hostfile != "": - script += "tmpfile=$(mktemp)\n" - script += f"echo {shlex.quote(hostfile)} >$tmpfile\n" + script += "_, hostfile = tempfile.mkstemp()\n" + script += "with open(hostfile, 'w') as f:\n" + script += f" f.write({repr(hostfile)})\n" if verbose: - script += "export MLX_RING_VERBOSE=1\n" - script += "export MLX_HOSTFILE=$tmpfile\n" - script += f"export MLX_RANK={rank}\n" + script += "env['MLX_RING_VERBOSE'] = '1'\n" + script += "env['MLX_HOSTFILE'] = hostfile\n" + script += f"env['MLX_RANK'] = '{rank}'\n" script += "\n" # Replace the process with the script - script += shlex.join(["exec", *command]) - script += "\n" + script += f"command = [{','.join(map(repr, command))}]\n" + script += "os.execve(command[0], command, env)\n" return script @@ -233,28 +243,37 @@ def launch_ring(parser, hosts, args, command): stop = False exit_codes = [None] * len(hosts) - def node_thread(rank, host, hostfile): + def node_thread(rank, host, hostfile, input_queue): is_local = host == "127.0.0.1" script = make_monitor_script( rank, hostfile, args.cwd, args.env, command, args.verbose ) script_b64 = base64.b64encode(script.encode()).decode() - cmd = f'echo "{script_b64}" | base64 -d | /bin/bash' + cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"' if not is_local: cmd = f"ssh {host} '{cmd}'" p = Popen( cmd, shell=True, + stdin=PIPE, stdout=PIPE, stderr=PIPE, ) os.set_blocking(p.stdout.fileno(), False) os.set_blocking(p.stderr.fileno(), False) + os.set_blocking(p.stdin.fileno(), False) # Repeat the stdout and stderr to the local machine + to_read = [p.stdout.fileno(), p.stderr.fileno()] + to_write = [p.stdin.fileno()] pidfile = "" + stdin_buffer = b"" while p.poll() is None: - rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0) + try: + stdin_buffer += input_queue.get_nowait() + except QueueEmpty: + pass + rlist, wlist, _ = select(to_read, to_write, [], 1.0) for fd in rlist: is_stdout = fd == p.stdout.fileno() outfile = sys.stdout if is_stdout else sys.stderr @@ -266,6 +285,11 @@ def launch_ring(parser, hosts, args, command): msg = msg[0] if msg else "" outfile.write(msg) + outfile.flush() + for fd in wlist: + if len(stdin_buffer) > 0: + n = os.write(fd, stdin_buffer) + stdin_buffer = stdin_buffer[n:] if stop: p.terminate() break @@ -310,16 +334,25 @@ def launch_ring(parser, hosts, args, command): log(args.verbose, "Running", shlex.join(command)) + input_queues = [] threads = [] for i, h in enumerate(hosts): if i + 1 == len(hosts): time.sleep(1.0) - t = threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile)) + input_queues.append(Queue()) + t = threading.Thread( + target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1]) + ) t.start() threads.append(t) + os.set_blocking(sys.stdin.fileno(), False) while not stop: - time.sleep(1.0) + rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0) + for fd in rlist: + stdin_buffer = os.read(fd, 8192) + for q in input_queues: + q.put(stdin_buffer) if any(t.is_alive() for t in threads): for i, t in enumerate(threads): if not t.is_alive(): @@ -730,6 +763,8 @@ def main(): if len(rest) == 0: parser.error("No script is provided") + if rest[0] == "--": + rest.pop(0) # Try to extract a list of hosts and corresponding ips if args.hostfile is not None: