Allow debugging in distributed mode (#1920)

This commit is contained in:
Angelos Katharopoulos 2025-03-04 13:01:10 -08:00 committed by GitHub
parent e613d0eaf0
commit a0737273d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,6 +14,8 @@ import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
@ -185,46 +187,54 @@ def parse_hostlist(parser, hostlist, repeats):
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
    """Build the source of the Python monitor script that launches ``command``.

    The returned text is a self-contained Python program. When executed it:

    1. writes its own PID to a fresh temporary file and prints that file's
       path on stdout, so the process can be killed later if needed;
    2. changes into the requested working directory;
    3. builds the child's environment from ``os.environ`` plus the given
       variables and the ring-backend variables;
    4. replaces itself with ``command`` via ``os.execve``.

    Args:
        rank: Rank of the node; exported as ``MLX_RANK``.
        hostfile: Hostfile contents. An empty string disables the
            ring-backend variables.
        cwd: Working directory to change to, or ``None`` to try the
            launcher's current directory without failing if it is missing.
        env: Iterable of ``KEY=VALUE`` (or bare ``KEY``) strings.
        command: Argument list; ``command[0]`` is the executable path.
        verbose: When true, ``MLX_RING_VERBOSE=1`` is exported.

    Returns:
        The script source as a single string.
    """
    lines = [
        # Imports that are used throughout
        "import os",
        "import sys",
        "import tempfile",
        "from pathlib import Path",
        # Write the PID to a file so we can kill the process if needed.
        # mkstemp returns an open descriptor; wrap it so it gets closed
        # instead of being dropped.
        "fd, pidfile = tempfile.mkstemp()",
        "with os.fdopen(fd, 'w') as f:",
        "    f.write(str(os.getpid()))",
        "print(pidfile, flush=True)",
    ]

    # Change the working directory if one was requested. Otherwise attempt to
    # change to the current one but don't fail if it wasn't possible.
    d = cwd or os.getcwd()
    # is_dir() rather than exists(): a path that exists but is not a
    # directory should take the error branch, not crash in os.chdir.
    lines.append(f"if Path({d!r}).is_dir():")
    lines.append(f"    os.chdir({d!r})")
    if cwd is not None:
        lines.append("else:")
        lines.append(
            f"    print('Failed to change directory to', {d!r}, file=sys.stderr)"
        )
        lines.append("    sys.exit(1)")

    # Add the environment variables that were given to us
    lines.append("env = dict(os.environ)")
    for e in env:
        key, _, value = e.partition("=")
        if not key or not all(c.isalnum() or c == "_" for c in key):
            log_warning(f"'{e}' is an invalid environment variable so it is ignored")
            continue
        # The value goes into a Python dict, not a shell, so it must not be
        # shell-quoted: repr() alone produces a safe literal.
        lines.append(f"env[{key!r}] = {value!r}")

    # Add the environment variables to enable the ring distributed backend.
    # NOTE(review): indentation was lost in the reviewed copy; MLX_RANK is
    # assumed to live inside this branch with the other ring variables --
    # confirm against the repository.
    if hostfile != "":
        lines.append("fd, hostfile = tempfile.mkstemp()")
        lines.append("with os.fdopen(fd, 'w') as f:")
        lines.append(f"    f.write({hostfile!r})")
        if verbose:
            lines.append("env['MLX_RING_VERBOSE'] = '1'")
        lines.append("env['MLX_HOSTFILE'] = hostfile")
        lines.append(f"env['MLX_RANK'] = {str(rank)!r}")
        lines.append("")

    # Replace the process with the script
    lines.append(f"command = {list(command)!r}")
    lines.append("os.execve(command[0], command, env)")
    return "\n".join(lines) + "\n"
@ -233,28 +243,37 @@ def launch_ring(parser, hosts, args, command):
stop = False
exit_codes = [None] * len(hosts)
def node_thread(rank, host, hostfile):
def node_thread(rank, host, hostfile, input_queue):
is_local = host == "127.0.0.1"
script = make_monitor_script(
rank, hostfile, args.cwd, args.env, command, args.verbose
)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'echo "{script_b64}" | base64 -d | /bin/bash'
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
p = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
# Repeat the stdout and stderr to the local machine
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
pidfile = ""
stdin_buffer = b""
while p.poll() is None:
rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0)
try:
stdin_buffer += input_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
outfile = sys.stdout if is_stdout else sys.stderr
@ -266,6 +285,11 @@ def launch_ring(parser, hosts, args, command):
msg = msg[0] if msg else ""
outfile.write(msg)
outfile.flush()
for fd in wlist:
if len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
if stop:
p.terminate()
break
@ -310,16 +334,25 @@ def launch_ring(parser, hosts, args, command):
log(args.verbose, "Running", shlex.join(command))
input_queues = []
threads = []
for i, h in enumerate(hosts):
if i + 1 == len(hosts):
time.sleep(1.0)
t = threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile))
input_queues.append(Queue())
t = threading.Thread(
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
while not stop:
time.sleep(1.0)
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in input_queues:
q.put(stdin_buffer)
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
@ -730,6 +763,8 @@ def main():
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None: