mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactoring launcher
This commit is contained in:
@@ -1086,18 +1086,17 @@ bool is_available() {
|
|||||||
|
|
||||||
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
std::shared_ptr<GroupImpl> init(bool strict /* = false */) {
|
||||||
const char* dev_file = std::getenv("MLX_IBV_DEVICES");
|
const char* dev_file = std::getenv("MLX_IBV_DEVICES");
|
||||||
const char* coordinator = std::getenv("MLX_IBV_COORDINATOR");
|
const char* coordinator = std::getenv("MLX_JACCL_COORDINATOR");
|
||||||
const char* rank_str = std::getenv("MLX_RANK");
|
const char* rank_str = std::getenv("MLX_RANK");
|
||||||
const char* ring_verbose = std::getenv("MLX_IBV_VERBOSE");
|
|
||||||
|
|
||||||
if (!is_available() || !dev_file || !coordinator || !rank_str) {
|
if (!is_available() || !dev_file || !coordinator || !rank_str) {
|
||||||
if (strict) {
|
if (strict) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[jaccl] You need to provide via environment variables a rank (MLX_RANK), "
|
msg << "[jaccl] You need to provide via environment variables a rank (MLX_RANK), "
|
||||||
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_IBV_COORDINATOR) "
|
<< "a device file (MLX_IBV_DEVICES) and a coordinator ip/port (MLX_JACCL_COORDINATOR) "
|
||||||
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
|
<< "but provided MLX_RANK=\"" << ((rank_str) ? rank_str : "")
|
||||||
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
|
<< "\", MLX_IBV_DEVICES=\"" << ((dev_file) ? dev_file : "")
|
||||||
<< "\" and MLX_IBV_COORDINATOR=\""
|
<< "\" and MLX_JACCL_COORDINATOR=\""
|
||||||
<< ((coordinator) ? coordinator : "");
|
<< ((coordinator) ? coordinator : "");
|
||||||
throw std::runtime_error(msg.str());
|
throw std::runtime_error(msg.str());
|
||||||
}
|
}
|
||||||
|
|||||||
85
python/mlx/_distributed_utils/common.py
Normal file
85
python/mlx/_distributed_utils/common.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Host:
|
||||||
|
rank: int
|
||||||
|
ssh_hostname: str
|
||||||
|
ips: list[str]
|
||||||
|
rdma: list[Optional[str]]
|
||||||
|
|
||||||
|
|
||||||
|
def positive_number(x):
|
||||||
|
x = int(x)
|
||||||
|
if x <= 0:
|
||||||
|
raise ValueError("Number should be positive")
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def log(verbose, *args, **kwargs):
|
||||||
|
if not verbose:
|
||||||
|
return
|
||||||
|
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def log_warning(*args, **kwargs):
|
||||||
|
kwargs["file"] = sys.stderr
|
||||||
|
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def log_error(*args, **kwargs):
|
||||||
|
kwargs["file"] = sys.stderr
|
||||||
|
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hostlist(parser, hostlist, repeats):
|
||||||
|
hosts = []
|
||||||
|
for i, h in enumerate(hostlist.split(",")):
|
||||||
|
if h == "":
|
||||||
|
raise ValueError("Hostname cannot be empty")
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(h)
|
||||||
|
ips = [h]
|
||||||
|
except ValueError:
|
||||||
|
ips = []
|
||||||
|
for i in range(repeats):
|
||||||
|
hosts.append(Host(i, h, ips))
|
||||||
|
return hosts
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hostfile(parser, hostfile):
|
||||||
|
"""Parse the json hostfile that contains both the hostnames to ssh into and
|
||||||
|
the ips to communicate over when using the ring backend.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
[
|
||||||
|
{"ssh": "hostname1", "ips": ["123.123.123.1"], "rdma": [null, "rdma_en2", "rdma_en3"]},
|
||||||
|
{"ssh": "hostname2", "ips": ["123.123.123.2"], "rdma": ["rdma_en2", null, "rdma_en3"]},
|
||||||
|
...
|
||||||
|
{"ssh": "hostnameN", "ips": ["123.123.123.N"], "rdma": ["rdma_en2", "rdma_en3", null]},
|
||||||
|
]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hostfile (str): The path to the json file containing the host
|
||||||
|
information
|
||||||
|
"""
|
||||||
|
hostfile = Path(hostfile)
|
||||||
|
if not hostfile.exists():
|
||||||
|
parser.error(f"Hostfile {str(hostfile)} doesn't exist")
|
||||||
|
|
||||||
|
try:
|
||||||
|
hosts = []
|
||||||
|
with open(hostfile) as f:
|
||||||
|
for i, h in enumerate(json.load(f)):
|
||||||
|
hosts.append(Host(i, h["ssh"], h.get("ips", []), h.get("rdma", [])))
|
||||||
|
return hosts
|
||||||
|
except Exception as e:
|
||||||
|
parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
|
||||||
0
python/mlx/_distributed_utils/config.py
Normal file
0
python/mlx/_distributed_utils/config.py
Normal file
@@ -832,7 +832,7 @@ def main():
|
|||||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--backend",
|
"--backend",
|
||||||
choices=["ring", "mpi", "nccl"],
|
choices=["ring", "mpi", "nccl", "jaccl"],
|
||||||
default="nccl" if mx.cuda.is_available() else "ring",
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
help="Which distributed backend to launch",
|
help="Which distributed backend to launch",
|
||||||
)
|
)
|
||||||
@@ -903,6 +903,8 @@ def main():
|
|||||||
launch_mpi(parser, hosts, args, rest)
|
launch_mpi(parser, hosts, args, rest)
|
||||||
if args.backend == "nccl":
|
if args.backend == "nccl":
|
||||||
launch_nccl(parser, hosts, args, rest)
|
launch_nccl(parser, hosts, args, rest)
|
||||||
|
if args.backend == "jaccl":
|
||||||
|
launch_jaccl(parser, hosts, args, rest)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
540
python/mlx/_distributed_utils/launch.py
Normal file
540
python/mlx/_distributed_utils/launch.py
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shlex
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from collections import Counter
|
||||||
|
from itertools import chain
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Empty as QueueEmpty
|
||||||
|
from queue import Queue
|
||||||
|
from select import select
|
||||||
|
from subprocess import PIPE, Popen, run
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from .common import log, log_warning, parse_hostfile, parse_hostlist, positive_number
|
||||||
|
|
||||||
|
|
||||||
|
class CommandProcess:
|
||||||
|
@property
|
||||||
|
def process(self):
|
||||||
|
"""Return the Popen object that refers to the current command."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exit_status(self):
|
||||||
|
"""Return a tuple (returncode, killed) for the command. It should be
|
||||||
|
(None, None) while the command is running normally."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def preprocess_output(self, data: str, is_stdout=False):
|
||||||
|
"""Preprocess the output of the command so that extra data can be
|
||||||
|
capture or the format changed on the fly."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def terminate(self):
|
||||||
|
"""Terminate or return the exit code."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteProcess(CommandProcess):
|
||||||
|
def __init__(self, rank, host, cwd, files, env, command):
|
||||||
|
is_local = host == "127.0.0.1"
|
||||||
|
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
|
||||||
|
script_b64 = base64.b64encode(script.encode()).decode()
|
||||||
|
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
||||||
|
if not is_local:
|
||||||
|
cmd = f"ssh {host} '{cmd}'"
|
||||||
|
|
||||||
|
self._host = host
|
||||||
|
self._pidfile = None
|
||||||
|
self._is_local = is_local
|
||||||
|
self._process = Popen(
|
||||||
|
cmd,
|
||||||
|
shell=True,
|
||||||
|
stdin=PIPE,
|
||||||
|
stdout=PIPE,
|
||||||
|
stderr=PIPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._killed = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def process(self):
|
||||||
|
return self._process
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exit_status(self):
|
||||||
|
return self._process.poll(), self._killed
|
||||||
|
|
||||||
|
def preprocess_output(self, data, is_stdout=False):
|
||||||
|
if self._pidfile is None:
|
||||||
|
pidfile, *rest = data.split("\n", maxsplit=1)
|
||||||
|
self._pidfile = pidfile
|
||||||
|
return rest[0] if rest else ""
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def terminate(self):
|
||||||
|
if self._killed:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.wait()
|
||||||
|
|
||||||
|
# Kill the remote program if possible
|
||||||
|
cmd = ""
|
||||||
|
cmd += f"pid=$(cat {self._pidfile}); "
|
||||||
|
cmd += "if ps -p $pid >/dev/null; then "
|
||||||
|
cmd += " kill $pid; "
|
||||||
|
cmd += " echo 1; "
|
||||||
|
cmd += "else "
|
||||||
|
cmd += " echo 0; "
|
||||||
|
cmd += "fi; "
|
||||||
|
cmd += f"rm {self._pidfile}"
|
||||||
|
if not self._is_local:
|
||||||
|
cmd = f"ssh {self._host} '{cmd}'"
|
||||||
|
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
|
||||||
|
|
||||||
|
self._killed = c.stdout.strip() == "1"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_monitor_script(rank, cwd, files, env, command):
|
||||||
|
# Imports that are used throughout
|
||||||
|
script = ""
|
||||||
|
script += "import os\n"
|
||||||
|
script += "import sys\n"
|
||||||
|
script += "import tempfile\n"
|
||||||
|
script += "from pathlib import Path\n"
|
||||||
|
|
||||||
|
# Write the PID to a file so we can kill the process if needed
|
||||||
|
script += "_, pidfile = tempfile.mkstemp() \n"
|
||||||
|
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
|
||||||
|
script += "print(pidfile, flush=True)\n"
|
||||||
|
|
||||||
|
# Change the working directory if one was requested. Otherwise attempt to
|
||||||
|
# change to the current one but don't fail if it wasn't possible.
|
||||||
|
d = cwd or os.getcwd()
|
||||||
|
script += f"if Path({repr(d)}).exists():\n"
|
||||||
|
script += f" os.chdir({repr(d)})\n"
|
||||||
|
if cwd is not None:
|
||||||
|
script += "else:\n"
|
||||||
|
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
|
||||||
|
script += f" sys.exit(1)\n"
|
||||||
|
|
||||||
|
# Add the environment variables that were requested
|
||||||
|
script += "env = dict(os.environ)\n"
|
||||||
|
for e in env:
|
||||||
|
key, *value = e.split("=", maxsplit=1)
|
||||||
|
value = shlex.quote(value[0]) if len(value) > 0 else ""
|
||||||
|
if not all(c.isalnum() or c == "_" for c in key):
|
||||||
|
log_warning(
|
||||||
|
f"'{e}' is an invalid environment variable so it is ignored"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
script += f"env[{repr(key)}] = {repr(value)}\n"
|
||||||
|
|
||||||
|
# Make the temporary files
|
||||||
|
for env_name, content in files.items():
|
||||||
|
script += "_, fname = tempfile.mkstemp()\n"
|
||||||
|
script += "with open(fname, 'w') as f:\n"
|
||||||
|
script += f" f.write({repr(content)})\n"
|
||||||
|
script += f"env[{repr(env_name)}] = fname\n"
|
||||||
|
|
||||||
|
# Finally add the rank
|
||||||
|
script += f"env['MLX_RANK'] = '{rank}'\n"
|
||||||
|
script += "\n"
|
||||||
|
|
||||||
|
# Replace the process with the script
|
||||||
|
script += f"command = [{','.join(map(repr, command))}]\n"
|
||||||
|
script += "os.execve(command[0], command, env)\n"
|
||||||
|
|
||||||
|
return script
|
||||||
|
|
||||||
|
|
||||||
|
def _launch_with_io(command_class, arguments, verbose):
|
||||||
|
stop = False
|
||||||
|
exit_codes = [(None, None)] * len(arguments)
|
||||||
|
|
||||||
|
def _thread_fn(rank, *args, **kwargs):
|
||||||
|
stdin_queue = kwargs.pop("stdin_queue")
|
||||||
|
stdout_queue = kwargs.pop("stdout_queue")
|
||||||
|
stderr_queue = kwargs.pop("stderr_queue")
|
||||||
|
|
||||||
|
command = command_class(rank, *args, **kwargs)
|
||||||
|
p = command.process
|
||||||
|
os.set_blocking(p.stdout.fileno(), False)
|
||||||
|
os.set_blocking(p.stderr.fileno(), False)
|
||||||
|
os.set_blocking(p.stdin.fileno(), False)
|
||||||
|
|
||||||
|
to_read = [p.stdout.fileno(), p.stderr.fileno()]
|
||||||
|
to_write = [p.stdin.fileno()]
|
||||||
|
|
||||||
|
stdin_buffer = b""
|
||||||
|
while p.poll() is None:
|
||||||
|
try:
|
||||||
|
stdin_buffer += stdin_queue.get_nowait()
|
||||||
|
except QueueEmpty:
|
||||||
|
pass
|
||||||
|
rlist, wlist, _ = select(to_read, to_write, [], 1.0)
|
||||||
|
for fd in rlist:
|
||||||
|
is_stdout = fd == p.stdout.fileno()
|
||||||
|
msg = os.read(fd, 8192).decode(errors="ignore")
|
||||||
|
msg = command.preprocess_output(msg, is_stdout)
|
||||||
|
if is_stdout:
|
||||||
|
stdout_queue.put(msg.encode())
|
||||||
|
else:
|
||||||
|
stderr_queue.put(msg.encode())
|
||||||
|
for fd in wlist:
|
||||||
|
if len(stdin_buffer) > 0:
|
||||||
|
n = os.write(fd, stdin_buffer)
|
||||||
|
stdin_buffer = stdin_buffer[n:]
|
||||||
|
if stop:
|
||||||
|
command.terminate()
|
||||||
|
break
|
||||||
|
exit_codes[rank] = command.exit_status
|
||||||
|
|
||||||
|
if exit_codes[rank][1]:
|
||||||
|
log_warning(f"Node with rank {rank} was killed")
|
||||||
|
elif exit_codes[rank][0] != 0:
|
||||||
|
log_warning(f"Node with rank {rank} exited with code {exit_codes[rank][0]}")
|
||||||
|
else:
|
||||||
|
log(verbose, f"Node with rank {rank} completed")
|
||||||
|
|
||||||
|
stdin_queues = []
|
||||||
|
stdout_queues = []
|
||||||
|
stderr_queues = []
|
||||||
|
threads = []
|
||||||
|
for i, (args, kwargs) in enumerate(arguments):
|
||||||
|
stdin_queues.append(Queue())
|
||||||
|
stdout_queues.append(Queue())
|
||||||
|
stderr_queues.append(Queue())
|
||||||
|
t = threading.Thread(
|
||||||
|
target=_thread_fn,
|
||||||
|
args=args,
|
||||||
|
kwargs=kwargs
|
||||||
|
| {
|
||||||
|
"stdin_queue": stdin_queues[-1],
|
||||||
|
"stdout_queue": stdout_queues[-1],
|
||||||
|
"stderr_queue": stderr_queues[-1],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
threads.append(t)
|
||||||
|
|
||||||
|
os.set_blocking(sys.stdin.fileno(), False)
|
||||||
|
os.set_blocking(sys.stdout.fileno(), True)
|
||||||
|
os.set_blocking(sys.stderr.fileno(), True)
|
||||||
|
while not stop or any(not q.empty() for q in chain(stdout_queues, stderr_queues)):
|
||||||
|
# Broadcast user input to the jobs
|
||||||
|
rlist, _, _ = select([sys.stdin.fileno()], [], [], 0.1)
|
||||||
|
for fd in rlist:
|
||||||
|
stdin_buffer = os.read(fd, 8192)
|
||||||
|
for q in stdin_queues:
|
||||||
|
q.put(stdin_buffer)
|
||||||
|
|
||||||
|
# Gather job output
|
||||||
|
for q in stdout_queues:
|
||||||
|
try:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stdout.buffer.write(q.get_nowait())
|
||||||
|
except QueueEmpty:
|
||||||
|
pass
|
||||||
|
for q in stderr_queues:
|
||||||
|
try:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stderr.buffer.write(q.get_nowait())
|
||||||
|
except QueueEmpty:
|
||||||
|
pass
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
sys.stderr.buffer.flush()
|
||||||
|
|
||||||
|
# Check if all are running and terminate otherwise
|
||||||
|
if any(t.is_alive() for t in threads):
|
||||||
|
for i, t in enumerate(threads):
|
||||||
|
if not t.is_alive():
|
||||||
|
if exit_codes[i][0] != 0:
|
||||||
|
stop = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Wait for the jobs to finish
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Process any remaining outputs
|
||||||
|
for q in stdout_queues:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stdout.buffer.write(q.get())
|
||||||
|
for q in stderr_queues:
|
||||||
|
while not q.empty():
|
||||||
|
sys.stderr.buffer.write(q.get())
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
sys.stderr.buffer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def launch_ring(parser, hosts, args, command):
|
||||||
|
if any(len(h.ips) == 0 for h in hosts):
|
||||||
|
parser.error(
|
||||||
|
"The ring backend requires IPs to be provided instead of hostnames"
|
||||||
|
)
|
||||||
|
|
||||||
|
port = args.starting_port
|
||||||
|
ring_hosts = []
|
||||||
|
for h in hosts:
|
||||||
|
node = []
|
||||||
|
for ip in h.ips:
|
||||||
|
for i in range(args.connections_per_ip):
|
||||||
|
node.append(f"{ip}:{port}")
|
||||||
|
port += 1
|
||||||
|
ring_hosts.append(node)
|
||||||
|
hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""
|
||||||
|
|
||||||
|
files = {"MLX_HOSTFILE": hostfile}
|
||||||
|
env = args.env
|
||||||
|
if args.verbose:
|
||||||
|
env.append("MLX_RING_VERBOSE=1")
|
||||||
|
cwd = args.cwd
|
||||||
|
|
||||||
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
|
_launch_with_io(
|
||||||
|
RemoteProcess,
|
||||||
|
[
|
||||||
|
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
||||||
|
for rank, h in enumerate(hosts)
|
||||||
|
],
|
||||||
|
args.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_nccl(parser, hosts, args, command):
|
||||||
|
if not hosts[0].ips:
|
||||||
|
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
|
||||||
|
|
||||||
|
master_host = hosts[0].ips[0]
|
||||||
|
master_port = args.nccl_port
|
||||||
|
world_size = len(hosts)
|
||||||
|
|
||||||
|
env = args.env
|
||||||
|
cwd = args.cwd
|
||||||
|
if args.verbose:
|
||||||
|
env.append("NCCL_DEBUG=INFO")
|
||||||
|
env.append(f"NCCL_HOST_IP={master_host}")
|
||||||
|
env.append(f"NCCL_PORT={master_port}")
|
||||||
|
env.append(f"MLX_WORLD_SIZE={world_size}")
|
||||||
|
|
||||||
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
|
_launch_with_io(
|
||||||
|
RemoteProcess,
|
||||||
|
[
|
||||||
|
(
|
||||||
|
(
|
||||||
|
rank,
|
||||||
|
h.ssh_hostname,
|
||||||
|
cwd,
|
||||||
|
{},
|
||||||
|
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
|
||||||
|
command,
|
||||||
|
),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
for rank, h in enumerate(hosts)
|
||||||
|
],
|
||||||
|
args.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def launch_jaccl(parser, hosts, args, command):
|
||||||
|
if not hosts[0].ips:
|
||||||
|
raise ValueError("Rank 0 should have an IP reachable from all other ranks")
|
||||||
|
|
||||||
|
have_rdmas = all(len(h.rdma) == len(hosts) for h in hosts)
|
||||||
|
have_nulls = all(h.rdma[i] is None for i, h in enumerate(hosts))
|
||||||
|
if not have_rdmas or not have_nulls:
|
||||||
|
raise ValueError("Malformed hostfile for jaccl backend")
|
||||||
|
|
||||||
|
coordinator = hosts[0].ips[0]
|
||||||
|
env = args.env
|
||||||
|
cwd = args.cwd
|
||||||
|
env.append(f"MLX_JACCL_COORDINATOR={coordinator}:{args.starting_port}")
|
||||||
|
files = {"MLX_IBV_DEVICES": json.dumps([h.rdma for h in hosts])}
|
||||||
|
|
||||||
|
log(args.verbose, "Running", shlex.join(command))
|
||||||
|
|
||||||
|
_launch_with_io(
|
||||||
|
RemoteProcess,
|
||||||
|
[
|
||||||
|
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
||||||
|
for rank, h in enumerate(hosts)
|
||||||
|
],
|
||||||
|
args.verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mpi_libname():
|
||||||
|
try:
|
||||||
|
ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
|
||||||
|
ompi_info = ompi_info.stdout.strip().decode()
|
||||||
|
|
||||||
|
if platform.system() == "Darwin":
|
||||||
|
otool_output = run(
|
||||||
|
["otool", "-L", ompi_info], check=True, capture_output=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
|
||||||
|
otool_output = otool_output.stdout.decode()
|
||||||
|
|
||||||
|
# StopIteration if not found
|
||||||
|
libmpi_line = next(
|
||||||
|
filter(lambda line: "libmpi" in line, otool_output.splitlines())
|
||||||
|
)
|
||||||
|
return libmpi_line.strip().split()[0].removeprefix("@rpath/")
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def launch_mpi(parser, hosts, args, command):
|
||||||
|
mpirun = run(["which", "mpirun"], check=True, capture_output=True)
|
||||||
|
mpirun = mpirun.stdout.strip().decode()
|
||||||
|
|
||||||
|
# Compatibility with homebrew and pip installs
|
||||||
|
mpi_libname = get_mpi_libname()
|
||||||
|
if mpi_libname is not None:
|
||||||
|
dyld = Path(mpirun).parent.parent / "lib"
|
||||||
|
args.env = [
|
||||||
|
f"DYLD_LIBRARY_PATH={str(dyld)}",
|
||||||
|
f"MLX_MPI_LIBNAME={mpi_libname}",
|
||||||
|
] + args.env
|
||||||
|
|
||||||
|
log(args.verbose, f"Using '{mpirun}'")
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||||
|
hosts = Counter((h.ssh_hostname for h in hosts))
|
||||||
|
for h, n in hosts.items():
|
||||||
|
print(f"{h} slots={n}", file=f)
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
mpirun,
|
||||||
|
"--output",
|
||||||
|
":raw", # do not line buffer output
|
||||||
|
"--hostfile",
|
||||||
|
f.name,
|
||||||
|
*(["-cwd", args.cwd] if args.cwd else []),
|
||||||
|
*sum((["-x", e] for e in args.env), []),
|
||||||
|
*sum([shlex.split(arg) for arg in args.mpi_arg], []),
|
||||||
|
"--",
|
||||||
|
*command,
|
||||||
|
]
|
||||||
|
log(args.verbose, "Running", " ".join(cmd))
|
||||||
|
try:
|
||||||
|
run(cmd)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
||||||
|
parser.add_argument(
|
||||||
|
"--print-python",
|
||||||
|
action="store_true",
|
||||||
|
help="Print the path to the current python executable and exit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--verbose", action="store_true", help="Print debug messages in stdout"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat-hosts",
|
||||||
|
"-n",
|
||||||
|
type=positive_number,
|
||||||
|
default=1,
|
||||||
|
help="Repeat each host a given number of times",
|
||||||
|
)
|
||||||
|
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["ring", "mpi", "nccl", "jaccl"],
|
||||||
|
default="nccl" if mx.cuda.is_available() else "ring",
|
||||||
|
help="Which distributed backend to launch",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--env",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Set environment variables for the jobs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--mpi-arg",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Arguments to pass directly to mpirun",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--connections-per-ip",
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help="How many connections per ip to use for the ring backend",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--starting-port",
|
||||||
|
"-p",
|
||||||
|
type=int,
|
||||||
|
default=32323,
|
||||||
|
help="For the ring backend listen on this port increasing by 1 per rank and IP",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cwd", help="Set the working directory on each node to the provided one"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--nccl-port",
|
||||||
|
type=int,
|
||||||
|
default=12345,
|
||||||
|
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args, rest = parser.parse_known_args()
|
||||||
|
|
||||||
|
if args.print_python:
|
||||||
|
print(sys.executable)
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(rest) == 0:
|
||||||
|
parser.error("No script is provided")
|
||||||
|
if rest[0] == "--":
|
||||||
|
rest.pop(0)
|
||||||
|
|
||||||
|
# Try to extract a list of hosts and corresponding ips
|
||||||
|
if args.hostfile is not None:
|
||||||
|
hosts = parse_hostfile(parser, args.hostfile)
|
||||||
|
else:
|
||||||
|
hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)
|
||||||
|
|
||||||
|
# Check if the script is a file and convert it to a full path
|
||||||
|
if (script := Path(rest[0])).exists() and script.is_file():
|
||||||
|
rest[0:1] = [sys.executable, str(script.resolve())]
|
||||||
|
elif (command := shutil.which(rest[0])) is not None:
|
||||||
|
rest[0] = command
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||||
|
|
||||||
|
# Launch
|
||||||
|
if args.backend == "ring":
|
||||||
|
launch_ring(parser, hosts, args, rest)
|
||||||
|
if args.backend == "mpi":
|
||||||
|
launch_mpi(parser, hosts, args, rest)
|
||||||
|
if args.backend == "nccl":
|
||||||
|
launch_nccl(parser, hosts, args, rest)
|
||||||
|
if args.backend == "jaccl":
|
||||||
|
launch_jaccl(parser, hosts, args, rest)
|
||||||
4
setup.py
4
setup.py
@@ -257,8 +257,8 @@ if __name__ == "__main__":
|
|||||||
}
|
}
|
||||||
entry_points = {
|
entry_points = {
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
"mlx.launch = mlx.distributed_run:main",
|
"mlx.launch = mlx._distributed_utils.launch:main",
|
||||||
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
install_requires = []
|
install_requires = []
|
||||||
|
|||||||
Reference in New Issue
Block a user