mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Allow debugging in distributed mode (#1920)
This commit is contained in:
parent
e613d0eaf0
commit
a0737273d3
@ -14,6 +14,8 @@ import time
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from queue import Empty as QueueEmpty
|
||||||
|
from queue import Queue
|
||||||
from select import select
|
from select import select
|
||||||
from subprocess import PIPE, Popen, run
|
from subprocess import PIPE, Popen, run
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -185,46 +187,54 @@ def parse_hostlist(parser, hostlist, repeats):
|
|||||||
|
|
||||||
|
|
||||||
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
    """Build the source of a small Python program that launches ``command``.

    The returned text is meant to be executed by a Python interpreter on the
    target node. When run, the generated program:

    1. writes its PID to a temporary file and prints that file's path as the
       first line of stdout, so the launcher can kill the process if needed;
    2. changes to the working directory;
    3. builds the environment for the command (the current environment plus
       the requested variables and the ring backend variables);
    4. replaces itself with ``command`` via ``os.execve``.

    Args:
        rank: Rank of the node; exported as ``MLX_RANK``.
        hostfile: Contents of the hostfile. When not the empty string it is
            written to a temporary file whose path is exported as
            ``MLX_HOSTFILE``.
        cwd: Directory to change to. If ``None`` the launcher's current
            directory is attempted on a best-effort basis; if given and it is
            not a directory on the node the generated program exits with
            status 1.
        env: Iterable of ``"KEY=value"`` (or bare ``"KEY"``) strings. Entries
            whose key is empty or contains characters other than
            alphanumerics and ``_`` are skipped with a warning.
        command: Sequence of strings; ``command[0]`` is the executable that is
            passed to ``os.execve``.
        verbose: When true (and a hostfile is given) ``MLX_RING_VERBOSE=1`` is
            exported as well.

    Returns:
        The generated program as a single string.
    """
    # Imports that are used throughout
    lines = [
        "import os",
        "import sys",
        "import tempfile",
        "from pathlib import Path",
    ]

    # Write the PID to a file so we can kill the process if needed. The file
    # descriptor returned by mkstemp is wrapped so that it is flushed and
    # closed deterministically instead of being leaked.
    lines += [
        "fd, pidfile = tempfile.mkstemp()",
        "with os.fdopen(fd, 'w') as f:",
        "    f.write(str(os.getpid()))",
        "print(pidfile, flush=True)",
    ]

    # Change the working directory if one was requested. Otherwise attempt to
    # change to the current one but don't fail if it wasn't possible.
    d = cwd or os.getcwd()
    # is_dir() rather than exists(): a regular file at that path must take the
    # failure branch instead of making os.chdir raise.
    lines.append(f"if Path({d!r}).is_dir():")
    lines.append(f"    os.chdir({d!r})")
    if cwd is not None:
        lines.append("else:")
        lines.append(
            f"    print('Failed to change directory to', {d!r}, file=sys.stderr)"
        )
        lines.append("    sys.exit(1)")

    # Add the environment variables that were given to us
    lines.append("env = dict(os.environ)")
    for e in env:
        key, _, value = e.partition("=")
        # An empty key would pass the character check vacuously and later be
        # rejected by os.execve, so treat it as invalid here.
        if not key or not all(c.isalnum() or c == "_" for c in key):
            log_warning(f"'{e}' is an invalid environment variable so it is ignored")
            continue
        # The value is embedded as a Python literal, so no shell quoting must
        # be applied: it would end up verbatim in the child's environment.
        lines.append(f"env[{key!r}] = {value!r}")

    # Add the environment variables to enable the ring distributed backend
    if hostfile != "":
        lines += [
            "fd, hostfile = tempfile.mkstemp()",
            "with os.fdopen(fd, 'w') as f:",
            f"    f.write({hostfile!r})",
        ]
        if verbose:
            lines.append("env['MLX_RING_VERBOSE'] = '1'")
        lines.append("env['MLX_HOSTFILE'] = hostfile")
    lines.append(f"env['MLX_RANK'] = {str(rank)!r}")
    lines.append("")

    # Replace the process with the script
    lines.append(f"command = {list(command)!r}")
    lines.append("os.execve(command[0], command, env)")

    return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
@ -233,28 +243,37 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
stop = False
|
stop = False
|
||||||
exit_codes = [None] * len(hosts)
|
exit_codes = [None] * len(hosts)
|
||||||
|
|
||||||
def node_thread(rank, host, hostfile):
|
def node_thread(rank, host, hostfile, input_queue):
|
||||||
is_local = host == "127.0.0.1"
|
is_local = host == "127.0.0.1"
|
||||||
script = make_monitor_script(
|
script = make_monitor_script(
|
||||||
rank, hostfile, args.cwd, args.env, command, args.verbose
|
rank, hostfile, args.cwd, args.env, command, args.verbose
|
||||||
)
|
)
|
||||||
script_b64 = base64.b64encode(script.encode()).decode()
|
script_b64 = base64.b64encode(script.encode()).decode()
|
||||||
cmd = f'echo "{script_b64}" | base64 -d | /bin/bash'
|
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
||||||
if not is_local:
|
if not is_local:
|
||||||
cmd = f"ssh {host} '{cmd}'"
|
cmd = f"ssh {host} '{cmd}'"
|
||||||
p = Popen(
|
p = Popen(
|
||||||
cmd,
|
cmd,
|
||||||
shell=True,
|
shell=True,
|
||||||
|
stdin=PIPE,
|
||||||
stdout=PIPE,
|
stdout=PIPE,
|
||||||
stderr=PIPE,
|
stderr=PIPE,
|
||||||
)
|
)
|
||||||
os.set_blocking(p.stdout.fileno(), False)
|
os.set_blocking(p.stdout.fileno(), False)
|
||||||
os.set_blocking(p.stderr.fileno(), False)
|
os.set_blocking(p.stderr.fileno(), False)
|
||||||
|
os.set_blocking(p.stdin.fileno(), False)
|
||||||
|
|
||||||
# Repeat the stdout and stderr to the local machine
|
# Repeat the stdout and stderr to the local machine
|
||||||
|
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
||||||
|
to_write = [p.stdin.fileno()]
|
||||||
pidfile = ""
|
pidfile = ""
|
||||||
|
stdin_buffer = b""
|
||||||
while p.poll() is None:
|
while p.poll() is None:
|
||||||
rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0)
|
try:
|
||||||
|
stdin_buffer += input_queue.get_nowait()
|
||||||
|
except QueueEmpty:
|
||||||
|
pass
|
||||||
|
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
||||||
for fd in rlist:
|
for fd in rlist:
|
||||||
is_stdout = fd == p.stdout.fileno()
|
is_stdout = fd == p.stdout.fileno()
|
||||||
outfile = sys.stdout if is_stdout else sys.stderr
|
outfile = sys.stdout if is_stdout else sys.stderr
|
||||||
@ -266,6 +285,11 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
msg = msg[0] if msg else ""
|
msg = msg[0] if msg else ""
|
||||||
|
|
||||||
outfile.write(msg)
|
outfile.write(msg)
|
||||||
|
outfile.flush()
|
||||||
|
for fd in wlist:
|
||||||
|
if len(stdin_buffer) > 0:
|
||||||
|
n = os.write(fd, stdin_buffer)
|
||||||
|
stdin_buffer = stdin_buffer[n:]
|
||||||
if stop:
|
if stop:
|
||||||
p.terminate()
|
p.terminate()
|
||||||
break
|
break
|
||||||
@ -310,16 +334,25 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
|
|
||||||
log(args.verbose, "Running", shlex.join(command))
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
|
input_queues = []
|
||||||
threads = []
|
threads = []
|
||||||
for i, h in enumerate(hosts):
|
for i, h in enumerate(hosts):
|
||||||
if i + 1 == len(hosts):
|
if i + 1 == len(hosts):
|
||||||
time.sleep(1.0)
|
time.sleep(1.0)
|
||||||
t = threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile))
|
input_queues.append(Queue())
|
||||||
|
t = threading.Thread(
|
||||||
|
target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
|
||||||
|
)
|
||||||
t.start()
|
t.start()
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
|
|
||||||
|
os.set_blocking(sys.stdin.fileno(), False)
|
||||||
while not stop:
|
while not stop:
|
||||||
time.sleep(1.0)
|
rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
|
||||||
|
for fd in rlist:
|
||||||
|
stdin_buffer = os.read(fd, 8192)
|
||||||
|
for q in input_queues:
|
||||||
|
q.put(stdin_buffer)
|
||||||
if any(t.is_alive() for t in threads):
|
if any(t.is_alive() for t in threads):
|
||||||
for i, t in enumerate(threads):
|
for i, t in enumerate(threads):
|
||||||
if not t.is_alive():
|
if not t.is_alive():
|
||||||
@ -730,6 +763,8 @@ def main():
|
|||||||
|
|
||||||
if len(rest) == 0:
|
if len(rest) == 0:
|
||||||
parser.error("No script is provided")
|
parser.error("No script is provided")
|
||||||
|
if rest[0] == "--":
|
||||||
|
rest.pop(0)
|
||||||
|
|
||||||
# Try to extract a list of hosts and corresponding ips
|
# Try to extract a list of hosts and corresponding ips
|
||||||
if args.hostfile is not None:
|
if args.hostfile is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user