Refactoring launcher

This commit is contained in:
Angelos Katharopoulos
2025-12-01 16:31:57 -08:00
parent 8fab4f0929
commit dd91ee9534
6 changed files with 633 additions and 7 deletions

View File

@@ -0,0 +1,85 @@
# Copyright © 2025 Apple Inc.
import ipaddress
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
rdma: list[Optional[str]]
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for i, h in enumerate(hostlist.split(",")):
if h == "":
raise ValueError("Hostname cannot be empty")
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
return hosts
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
return hosts
except Exception as e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")

View File

View File

@@ -832,7 +832,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl"],
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
@@ -903,6 +903,8 @@ def main():
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)
if __name__ == "__main__":

View File

@@ -0,0 +1,540 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import json
import os
import shlex
import shutil
import sys
import tempfile
import threading
from collections import Counter
from itertools import chain
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
import mlx.core as mx
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
class CommandProcess:
@property
def process(self):
"""Return the Popen object that refers to the current command."""
raise NotImplementedError()
@property
def exit_status(self):
"""Return a tuple (returncode, killed) for the command. It should be
(None, None) while the command is running normally."""
raise NotImplementedError()
def preprocess_output(self, data: str, is_stdout=False):
"""Preprocess the output of the command so that extra data can be
capture or the format changed on the fly."""
raise NotImplementedError()
def terminate(self):
"""Terminate or return the exit code."""
raise NotImplementedError()
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, cwd, files, env, command):
is_local = host == "127.0.0.1"
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
self._host = host
self._pidfile = None
self._is_local = is_local
self._process = Popen(
cmd,
shell=True,
stdin=PIPE,
stdout=PIPE,
stderr=PIPE,
)
self._killed = False
@property
def process(self):
return self._process
@property
def exit_status(self):
return self._process.poll(), self._killed
def preprocess_output(self, data, is_stdout=False):
if self._pidfile is None:
pidfile, *rest = data.split("\n", maxsplit=1)
self._pidfile = pidfile
return rest[0] if rest else ""
return data
def terminate(self):
if self._killed:
return
self._process.terminate()
self._process.wait()
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {self._pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {self._pidfile}"
if not self._is_local:
cmd = f"ssh {self._host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_monitor_script(rank, cwd, files, env, command):
# Imports that are used throughout
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
if cwd is not None:
script += "else:\n"
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
script += f" sys.exit(1)\n"
# Add the environment variables that were requested
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
# Make the temporary files
for env_name, content in files.items():
script += "_, fname = tempfile.mkstemp()\n"
script += "with open(fname, 'w') as f:\n"
script += f" f.write({repr(content)})\n"
script += f"env[{repr(env_name)}] = fname\n"
# Finally add the rank
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
return script
def _launch_with_io(command_class, arguments, verbose):
stop = False
exit_codes = [(None, None)] * len(arguments)
def _thread_fn(rank, *args, **kwargs):
stdin_queue = kwargs.pop("stdin_queue")
stdout_queue = kwargs.pop("stdout_queue")
stderr_queue = kwargs.pop("stderr_queue")
command = command_class(rank, *args, **kwargs)
p = command.process
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
os.set_blocking(p.stdin.fileno(), False)
to_read = [p.stdout.fileno(), p.stderr.fileno()]
to_write = [p.stdin.fileno()]
stdin_buffer = b""
while p.poll() is None:
try:
stdin_buffer += stdin_queue.get_nowait()
except QueueEmpty:
pass
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
msg = os.read(fd, 8192).decode(errors="ignore")
msg = command.preprocess_output(msg, is_stdout)
if is_stdout:
stdout_queue.put(msg.encode())
else:
stderr_queue.put(msg.encode())
for fd in wlist:
if len(stdin_buffer) > 0:
n = os.write(fd, stdin_buffer)
stdin_buffer = stdin_buffer[n:]
if stop:
command.terminate()
break
exit_codes[rank] = command.exit_status
if exit_codes[rank][1]:
log_warning(f"Node with rank {rank} was killed")
elif exit_codes[rank][0] != 0:
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
else:
log(verbose, f"Node with rank {rank} completed")
stdin_queues = []
stdout_queues = []
stderr_queues = []
threads = []
for i, (args, kwargs) in enumerate(arguments):
stdin_queues.append(Queue())
stdout_queues.append(Queue())
stderr_queues.append(Queue())
t = threading.Thread(
target=_thread_fn,
args=args,
kwargs=kwargs
| {
"stdin_queue": stdin_queues[-1],
"stdout_queue": stdout_queues[-1],
"stderr_queue": stderr_queues[-1],
},
)
t.start()
threads.append(t)
os.set_blocking(sys.stdin.fileno(), False)
os.set_blocking(sys.stdout.fileno(), True)
os.set_blocking(sys.stderr.fileno(), True)
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
# Broadcast user input to the jobs
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
for fd in rlist:
stdin_buffer = os.read(fd, 8192)
for q in stdin_queues:
q.put(stdin_buffer)
# Gather job output
for q in stdout_queues:
try:
while not q.empty():
sys.stdout.buffer.write(q.get_nowait())
except QueueEmpty:
pass
for q in stderr_queues:
try:
while not q.empty():
sys.stderr.buffer.write(q.get_nowait())
except QueueEmpty:
pass
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
# Check if all are running and terminate otherwise
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i][0] != 0:
stop = True
break
else:
break
# Wait for the jobs to finish
for t in threads:
t.join()
# Process any remaining outputs
for q in stdout_queues:
while not q.empty():
sys.stdout.buffer.write(q.get())
for q in stderr_queues:
while not q.empty():
sys.stderr.buffer.write(q.get())
sys.stdout.buffer.flush()
sys.stderr.buffer.flush()
def launch_ring(parser, hosts, args, command):
if any(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = args.starting_port
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
files = {"MLX_HOSTFILE": hostfile}
env = args.env
if args.verbose:
env.append("MLX_RING_VERBOSE=1")
cwd = args.cwd
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_nccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
master_host = hosts[0].ips[0]
master_port = args.nccl_port
world_size = len(hosts)
env = args.env
cwd = args.cwd
if args.verbose:
env.append("NCCL_DEBUG=INFO")
env.append(f"NCCL_HOST_IP={master_host}")
env.append(f"NCCL_PORT={master_port}")
env.append(f"MLX_WORLD_SIZE={world_size}")
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
(
(
rank,
h.ssh_hostname,
cwd,
{},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
command,
),
{},
)
for rank, h in enumerate(hosts)
],
args.verbose,
)
def launch_jaccl(parser, hosts, args, command):
if not hosts[0].ips:
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
if not have_rdmas or not have_nulls:
raise ValueError("Malformed hostfile for jaccl backend")
coordinator = hosts[0].ips[0]
env = args.env
cwd = args.cwd
env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}
log(args.verbose, "Running", shlex.join(command))
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
)
def get_mpi_libname():
try:
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
ompi_info = ompi_info.stdout.strip().decode()
if platform.system() == "Darwin":
otool_output = run(
["otool", "-L", ompi_info], check=True, capture_output=True
)
else:
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
otool_output = otool_output.stdout.decode()
# StopIteration if not found
libmpi_line = next(
filter(lambda line: "libmpi" in line, otool_output.splitlines())
)
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
except:
return None
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Compatibility with homebrew and pip installs
mpi_libname = get_mpi_libname()
if mpi_libname is not None:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [
f"DYLD_LIBRARY_PATH={str(dyld)}",
f"MLX_MPI_LIBNAME={mpi_libname}",
] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--output",
":raw", # do not line buffer output
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--print-python",
action="store_true",
help="Print the path to the current python executable and exit",
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi", "nccl", "jaccl"],
default="nccl" if mx.cuda.is_available() else "ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--starting-port",
"-p",
type=int,
default=32323,
help="For the ring backend listen on this port increasing by 1 per rank and IP",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
return
if len(rest) == 0:
parser.error("No script is provided")
if rest[0] == "--":
rest.pop(0)
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [sys.executable, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if args.backend == "jaccl":
launch_jaccl(parser, hosts, args, rest)