Small distributed launch helper (#1810)

This commit is contained in:
Angelos Katharopoulos 2025-01-29 17:55:04 -08:00 committed by GitHub
parent 4758c8baa1
commit ded914f442
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 338 additions and 26 deletions

View File

@ -160,7 +160,7 @@ jobs:
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
/bin/bash python/tests/run_ring_test.sh mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py
- run: - run:
name: Build example extension name: Build example extension
command: | command: |

View File

@ -0,0 +1,336 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import ipaddress
import json
import os
import shlex
import sys
import tempfile
import threading
import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from select import select
from subprocess import PIPE, Popen, run
@dataclass
class Host:
rank: int
ssh_hostname: str
ips: list[str]
def positive_number(x):
x = int(x)
if x <= 0:
raise ValueError("Number should be positive")
return x
def log(verbose, *args, **kwargs):
if not verbose:
return
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend.
Example:
[
{"ssh": "hostname1", "ips": ["123.123.123.1"]},
{"ssh": "hostname2", "ips": ["123.123.123.2"]},
...
{"ssh": "hostnameN", "ips": ["123.123.123.N"]},
]
Args:
hostfile (str): The path to the json file containing the host
information
"""
hostfile = Path(hostfile)
if not hostfile.exists():
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
try:
hosts = []
with open(hostfile) as f:
for i, h in enumerate(json.load(f)):
hosts.append(Host(i, h["ssh"], h.get("ips", [])))
return hosts
except e:
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
def parse_hostlist(parser, hostlist, repeats):
hosts = []
for i, h in enumerate(hostlist.split(",")):
try:
ipaddress.ip_address(h)
ips = [h]
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
return hosts
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
script = ""
# Write the PID to a file so we can kill the process if needed
script += "pidfile=$(mktemp)\n"
script += "echo $$ >$pidfile\n"
script += "echo $pidfile\n"
# Change the working directory if one was requested. Otherwise attempt to
# change to change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if [ -d {shlex.quote(d)} ]; then\n"
script += f" cd {shlex.quote(d)}\n"
if cwd is not None:
script += "else\n"
script += f" echo Failed to change directory to {shlex.quote(d)} 1>&2\n"
script += f" exit 1\n"
script += "fi\n"
# Add the environment variables that were given to us
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
if not all(c.isalnum() or c == "_" for c in key):
log_warning(f"'{e}' is an invalid environment variable so it is ignored")
continue
script += f"export {key}={value}\n"
# Add the environment variables to enable the ring distributed backend
if hostfile != "":
script += "tmpfile=$(mktemp)\n"
script += f"echo {shlex.quote(hostfile)} >$tmpfile\n"
if verbose:
script += "export MLX_RING_VERBOSE=1\n"
script += "export MLX_HOSTFILE=$tmpfile\n"
script += f"export MLX_RANK={rank}\n"
script += "\n"
# Replace the process with the script
script += shlex.join(["exec", sys.executable, *command])
script += "\n"
return script
def launch_ring(parser, hosts, args, command):
stop = False
exit_codes = [None] * len(hosts)
def node_thread(rank, host, hostfile):
is_local = host == "127.0.0.1"
script = make_monitor_script(
rank, hostfile, args.cwd, args.env, command, args.verbose
)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'echo "{script_b64}" | base64 -d | /bin/bash'
if not is_local:
cmd = f"ssh {host} '{cmd}'"
p = Popen(
cmd,
shell=True,
stdout=PIPE,
stderr=PIPE,
)
os.set_blocking(p.stdout.fileno(), False)
os.set_blocking(p.stderr.fileno(), False)
# Repeat the stdout and stderr to the local machine
pidfile = ""
while p.poll() is None:
rlist, _, _ = select([p.stdout.fileno(), p.stderr.fileno()], [], [], 1.0)
for fd in rlist:
is_stdout = fd == p.stdout.fileno()
outfile = sys.stdout if is_stdout else sys.stderr
msg = os.read(fd, 8192).decode(errors="ignore")
# Fetch the PID file first if we haven't already
if pidfile == "":
pidfile, *msg = msg.split("\n", maxsplit=1)
msg = msg[0] if msg else ""
outfile.write(msg)
if stop:
p.terminate()
break
p.wait()
exit_codes[rank] = p.returncode
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {pidfile}"
if not is_local:
cmd = f"ssh {host} '{cmd}'"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
if c.stdout.strip() == "1":
log_warning(f"Node with rank {rank} was killed")
elif p.returncode != 0:
log_warning(f"Node with rank {rank} exited with code {p.returncode}")
else:
log(args.verbose, f"Node with rank {rank} completed")
if all(len(h.ips) == 0 for h in hosts):
parser.error(
"The ring backend requires IPs to be provided instead of hostnames"
)
port = 5000
ring_hosts = []
for h in hosts:
node = []
for ip in h.ips:
for i in range(args.connections_per_ip):
node.append(f"{ip}:{port}")
port += 1
ring_hosts.append(node)
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
log(args.verbose, "Running", shlex.join([sys.executable, *command]))
threads = []
for i, h in enumerate(hosts):
if i + 1 == len(hosts):
time.sleep(1.0)
t = threading.Thread(target=node_thread, args=(i, h.ssh_hostname, hostfile))
t.start()
threads.append(t)
while not stop:
time.sleep(1.0)
if any(t.is_alive() for t in threads):
for i, t in enumerate(threads):
if not t.is_alive():
if exit_codes[i] != 0:
stop = True
break
else:
break
for t in threads:
t.join()
def launch_mpi(parser, hosts, args, command):
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
mpirun = mpirun.stdout.strip().decode()
# Homebrew libmpi doesn't work with anaconda python out of the box.
# TODO: Check if we should do this with every mpirun
if "homebrew" in mpirun:
dyld = Path(mpirun).parent.parent / "lib"
args.env = [f"DYLD_LIBRARY_PATH={str(dyld)}"] + args.env
log(args.verbose, f"Using '{mpirun}'")
with tempfile.NamedTemporaryFile(mode="w") as f:
hosts = Counter((h.ssh_hostname for h in hosts))
for h, n in hosts.items():
print(f"{h} slots={n}", file=f)
f.flush()
cmd = [
mpirun,
"--hostfile",
f.name,
*(["-cwd", args.cwd] if args.cwd else []),
*sum((["-x", e] for e in args.env), []),
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
"--",
sys.executable,
*command,
]
log(args.verbose, "Running", " ".join(cmd))
try:
run(cmd)
except KeyboardInterrupt:
pass
def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--repeat-hosts",
"-n",
type=positive_number,
default=1,
help="Repeat each host a given number of times",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi"],
default="ring",
help="Which distributed backend to launch",
)
parser.add_argument(
"--env",
action="append",
default=[],
help="Set environment variables for the jobs",
)
parser.add_argument(
"--mpi-arg",
action="append",
default=[],
help="Arguments to pass directly to mpirun",
)
parser.add_argument(
"--connections-per-ip",
default=1,
type=int,
help="How many connections per ip to use for the ring backend",
)
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
args, rest = parser.parse_known_args()
# Try to extract a list of hosts and corresponding ips
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
# Check if the script is a file and convert it to a full path
script = Path(rest[0])
if script.exists():
rest[0] = str(script.resolve())
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
elif args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if __name__ == "__main__":
main()

View File

@ -1,25 +0,0 @@
#!/bin/bash
tmpfile=$(mktemp)
cat <<HOSTFILE >$tmpfile
[
["127.0.0.1:5000"],
["127.0.0.1:5001"],
["127.0.0.1:5002"],
["127.0.0.1:5003"],
["127.0.0.1:5004"],
["127.0.0.1:5005"],
["127.0.0.1:5006"],
["127.0.0.1:5007"]
]
HOSTFILE
ring_test="$(dirname ${BASH_SOURCE[0]})/ring_test_distributed.py"
for i in {0..7}; do
if (($i == 7)); then
sleep 1
fi
DEVICE=cpu MLX_RING_VERBOSE=1 MLX_HOSTFILE=$tmpfile MLX_RANK=$i python $ring_test &
done
wait

View File

@ -194,6 +194,7 @@ if __name__ == "__main__":
"typing_extensions", "typing_extensions",
], ],
}, },
entry_points={"console_scripts": ["mlx.launch = mlx.distributed_run:main"]},
ext_modules=[CMakeExtension("mlx.core")], ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
zip_safe=False, zip_safe=False,