Compare commits

..

4 Commits

Author SHA1 Message Date
Angelos Katharopoulos
ebda161a86 Remove old joined script 2025-12-09 13:39:57 -08:00
Angelos Katharopoulos
fa31a4b295 Add more checks and improve errors 2025-12-09 13:36:17 -08:00
Angelos Katharopoulos
9d707ba3b5 Remove python from the launch script 2025-12-09 13:04:37 -08:00
Angelos Katharopoulos
405d30b6e5 Refactor distributed config 2025-12-09 05:58:44 -08:00
5 changed files with 630 additions and 957 deletions

View File

@@ -1,5 +1,6 @@
# Copyright © 2025 Apple Inc.
import argparse
import ipaddress
import json
import sys
@@ -16,6 +17,14 @@ class Host:
rdma: list[Optional[str]]
class OptionalBoolAction(argparse.Action):
    """Argparse action backing a paired ``--flag`` / ``--no-flag`` switch.

    The destination attribute becomes ``False`` when the option string used
    on the command line starts with ``--no-`` and ``True`` otherwise.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        enabled = not option_string.startswith("--no-")
        setattr(namespace, self.dest, enabled)
def positive_number(x):
x = int(x)
if x <= 0:
@@ -26,6 +35,7 @@ def positive_number(x):
def log(verbose, *args, **kwargs):
    """Print a green ``[INFO]`` message to stderr when ``verbose`` is truthy.

    Extra positional and keyword arguments are forwarded to ``print``; any
    ``file`` keyword supplied by the caller is replaced with ``sys.stderr``.
    """
    if verbose:
        kwargs["file"] = sys.stderr
        print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
@@ -50,7 +60,7 @@ def parse_hostlist(parser, hostlist, repeats):
except ValueError:
ips = []
for i in range(repeats):
hosts.append(Host(i, h, ips))
hosts.append(Host(i, h, ips, []))
return hosts

View File

@@ -0,0 +1,568 @@
# Copyright © 2025 Apple Inc.
import argparse
import json
import shlex
import sys
import threading
from collections import defaultdict
from dataclasses import dataclass
from subprocess import DEVNULL, run
from typing import Optional

import mlx.core as mx

from .common import (
    Host,
    OptionalBoolAction,
    log,
    log_error,
    log_warning,
    parse_hostfile,
    parse_hostlist,
)
@dataclass
class SSHInfo:
    """Result of probing one host over ssh.

    Instances are truthy exactly when ``can_ssh`` is true, so a list of them
    can be passed straight to ``all()``.
    """

    # True when the `ssh <host> echo success` probe exited with status 0
    can_ssh: bool
    # True when the `ssh <host> sudo ls` probe exited with status 0
    has_sudo: bool

    def __bool__(self):
        return self.can_ssh
@dataclass
class ThunderboltPort:
    """One thunderbolt port of a host as parsed by extract_connectivity."""

    # Network device name looked up as "Thunderbolt <tag>" in the hardware port map
    iface: str
    # The port's own `domain_uuid_key`
    uuid: str
    # `domain_uuid_key` of the first connected item that has one, else None
    connected_to: Optional[str]
@dataclass
class ThunderboltHost:
    """The thunderbolt ports found on one host."""

    # `device_name_key` reported by system_profiler ("" if no port was parsed)
    name: str
    # Ports of the host; extract_connectivity stores them sorted by `iface`
    ports: list[ThunderboltPort]
def add_ethernet_ips(hosts, verbose=False):
    """Append the ``en0`` address of every host to its ``ips`` list.

    The address is whatever ``ipconfig getifaddr en0`` prints when run on the
    host through ssh, stripped of surrounding whitespace. The exit status of
    the command is not inspected.

    Args:
        hosts: Objects exposing ``ssh_hostname`` and a mutable ``ips`` list.
        verbose: When truthy, log each query.
    """
    for host in hosts:
        log(verbose, "Getting the ip from", host.ssh_hostname)
        query = ["ssh", host.ssh_hostname, "ipconfig", "getifaddr", "en0"]
        answer = run(query, capture_output=True, text=True)
        host.ips.append(answer.stdout.strip())
def check_rdma(hosts, verbose=False):
    """Warn about hosts that do not report any ``rdma_*`` device.

    Runs ``ibv_devices`` on every host through ssh and looks for a
    whitespace-separated token starting with ``rdma_``. Hosts without one get
    an individual warning, followed by a single pointer to the documentation.
    Nothing is returned and nothing is raised for missing RDMA support.

    Args:
        hosts: Objects exposing ``ssh_hostname``.
        verbose: When truthy, log each probe.
    """
    missing_rdma = False
    for host in hosts:
        log(verbose, "Checking that", host.ssh_hostname, "supports RDMA")
        probe = run(
            ["ssh", host.ssh_hostname, "ibv_devices"],
            capture_output=True,
            text=True,
        )
        tokens = probe.stdout.strip().split()
        if not any(token.startswith("rdma_") for token in tokens):
            log_warning(host.ssh_hostname, "does not seem to have RDMA enabled")
            missing_rdma = True

    if not missing_rdma:
        return

    log_warning()
    log_warning(
        "Some of the hosts don't have RDMA enabled or they don't support RDMA."
    )
    log_warning()
    log_warning(
        "See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
    )
    log_warning("for instructions on how to enable RDMA.")
def can_auto_setup(hosts, sshinfo, auto_setup=False):
    """Report whether every host has sudo available for automatic setup.

    When automatic setup was requested but some host lacks sudo, a warning
    listing the offending hosts is emitted.

    Args:
        hosts: Objects exposing ``ssh_hostname``, parallel to ``sshinfo``.
        sshinfo: Objects exposing ``has_sudo``.
        auto_setup: Whether the user asked for automatic setup.

    Returns:
        True when all entries of ``sshinfo`` have ``has_sudo`` set.
    """
    everyone_has_sudo = all(entry.has_sudo for entry in sshinfo)
    if auto_setup and not everyone_has_sudo:
        log_warning(
            "Automatic setup requested but the following hosts do not have passwordless sudo"
        )
        for host, entry in zip(hosts, sshinfo):
            if entry.has_sudo:
                continue
            log_warning(" - ", host.ssh_hostname)
    return everyone_has_sudo
class IPConfigurator:
    """Assign point-to-point IPs to thunderbolt links and apply them.

    Every cable between two known hosts gets its own block of four addresses
    in ``192.168.X.Y`` (matching the ``255.255.255.252`` netmask used by
    ``setup``): the side discovered first receives ``.Y+1`` and its peer
    ``.Y+2``.

    Attributes:
        ips: ``defaultdict`` mapping ``(node, peer_node)`` to a list of
            ``(iface, ip)`` tuples, one per cable between the two nodes. The
            lists for ``(a, b)`` and ``(b, a)`` are index-aligned.
        hosts: The hosts given to the constructor.
        tb_hosts: The thunderbolt description of each host.
    """

    def __init__(self, hosts, tb_hosts, uuid_reverse_index):
        """Compute the address assignment for all links.

        Args:
            hosts: Objects exposing ``ssh_hostname``; parallel to ``tb_hosts``.
            tb_hosts: Objects exposing ``ports`` (each port with ``iface``
                and ``connected_to``).
            uuid_reverse_index: Maps a port uuid to ``(host index, port index)``.

        Raises:
            ValueError: If the ``192.168.0.0/16`` range is exhausted.
        """
        assigned = set()
        ips = defaultdict(list)
        ip0 = 0
        ip1 = 0
        for src_node, h in enumerate(tb_hosts):
            for src_port, p in enumerate(h.ports):
                if not p.connected_to:
                    continue
                if (src_node, src_port) in assigned:
                    continue

                # A port can be cabled to a machine that is not part of the
                # host list. Such a peer has no entry in the index, so the
                # link is ignored instead of failing with a KeyError.
                peer = uuid_reverse_index.get(p.connected_to)
                if peer is None:
                    continue
                dst_node, dst_port = peer

                ip_src = f"192.168.{ip0}.{ip1 + 1}"
                ip_dst = f"192.168.{ip0}.{ip1 + 2}"
                iface_src = p.iface
                iface_dst = tb_hosts[dst_node].ports[dst_port].iface
                ips[src_node, dst_node].append((iface_src, ip_src))
                ips[dst_node, src_node].append((iface_dst, ip_dst))
                # Mark both ends so the cable is not processed again from
                # the peer's side.
                assigned.add((src_node, src_port))
                assigned.add((dst_node, dst_port))

                # Advance to the next block of four addresses
                ip1 += 4
                if ip1 > 255:
                    ip0 += 1
                    ip1 = 0
                if ip0 > 255:
                    raise ValueError("Ran out of available local IPs")

        self.ips = ips
        self.hosts = hosts
        self.tb_hosts = tb_hosts

    def setup(self, verbose=False, auto_setup=False):
        """Apply, or print for manual application, the per-host commands.

        Args:
            verbose: When truthy, log the ssh command that is executed.
            auto_setup: When truthy the commands are run on each host through
                ssh; otherwise they are printed and the user is asked to
                press enter before moving to the next host.
        """
        netmask = "255.255.255.252"
        for i, (h, _) in enumerate(zip(self.hosts, self.tb_hosts)):
            command = ""
            command += "sudo ifconfig bridge0 down\n"
            for j in range(len(self.hosts)):
                # Membership test (not indexing) so the defaultdict does not
                # grow empty entries.
                if i == j or (i, j) not in self.ips:
                    continue
                for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
                    command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
                    command += f"sudo route change {peer} -interface {iface}\n"

            if auto_setup:
                print(f"Running auto setup for {h.ssh_hostname}")
                command = command.strip().replace("\n", " ; ")
                command = ["ssh", h.ssh_hostname, command]
                log(verbose, shlex.join(command))
                run(command)
            else:
                msg = f"Setup for {h.ssh_hostname}"
                print(msg)
                print("=" * len(msg))
                print(command)
                input("Enter to continue")
                print()
def parse_hardware_ports(ports_string):
    """Map hardware port names to device names.

    Args:
        ports_string: UTF-8 encoded bytes in the format printed by
            ``networksetup -listallhardwareports``.

    Returns:
        A dict from the text following ``Hardware Port: `` to the text
        following ``Device: `` on the next device line.
    """
    mapping = {}
    pending_name = None
    for raw_line in ports_string.decode("utf-8").split("\n"):
        if raw_line.startswith("Hardware Port:"):
            # Drop the 15 characters of "Hardware Port: "
            pending_name = raw_line.strip()[15:]
            continue
        if raw_line.startswith("Device:"):
            # Drop the 8 characters of "Device: "
            mapping[pending_name] = raw_line.strip()[8:]
            pending_name = None
    return mapping
def extract_connectivity(hosts, verbose):
    """Query every host for its thunderbolt ports and their peers.

    For each host ``system_profiler SPThunderboltDataType -json`` and
    ``networksetup -listallhardwareports`` are run through ssh (all profiler
    queries first, then all interface queries).

    Args:
        hosts: Objects exposing ``ssh_hostname``.
        verbose: When truthy, log each remote query.

    Returns:
        A tuple ``(tb_hosts, uuid_reverse_index)`` where ``tb_hosts`` holds
        one ``ThunderboltHost`` per host (ports sorted by interface name) and
        ``uuid_reverse_index`` maps a port uuid to ``(host index, port index)``.
    """

    def remote_output(host, what, *remote_command):
        # Log, then run the command on the host and hand back raw stdout
        log(verbose, what, host.ssh_hostname)
        completed = run(
            ["ssh", host.ssh_hostname, *remote_command],
            capture_output=True,
        )
        return completed.stdout

    # Extract the current connectivity from the remote hosts
    thunderbolt_connections = [
        json.loads(
            remote_output(
                h,
                "Getting connectivity from",
                "system_profiler",
                "SPThunderboltDataType",
                "-json",
            )
        )
        for h in hosts
    ]
    interface_maps = [
        parse_hardware_ports(
            remote_output(
                h,
                "Getting interface names from",
                "networksetup",
                "-listallhardwareports",
            )
        )
        for h in hosts
    ]

    # Parse the connectivity into some simple dataclasses
    tb_hosts = []
    for profile, iface_map in zip(thunderbolt_connections, interface_maps):
        name = ""
        ports = []
        for bus in profile["SPThunderboltDataType"]:
            uuid = bus.get("domain_uuid_key")
            if uuid is None:
                continue
            name = bus["device_name_key"]
            tag = bus["receptacle_1_tag"]["receptacle_id_key"]
            # Only connected items that carry their own uuid count as peers
            peer_uuids = [
                item["domain_uuid_key"]
                for item in bus.get("_items", [])
                if "domain_uuid_key" in item
            ]
            connected_to = peer_uuids[0] if peer_uuids else None
            ports.append(
                ThunderboltPort(iface_map[f"Thunderbolt {tag}"], uuid, connected_to)
            )
        ports.sort(key=lambda port: port.iface)
        tb_hosts.append(ThunderboltHost(name, ports))

    # Create a reverse index to be able to map uuids to (host, port) quickly
    uuid_reverse_index = {
        port.uuid: (i, j)
        for i, tb_host in enumerate(tb_hosts)
        for j, port in enumerate(tb_host.ports)
    }
    return tb_hosts, uuid_reverse_index
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
    """Count the thunderbolt cables between every pair of hosts.

    Args:
        tb_hosts: Objects exposing ``ports`` (each with ``connected_to``).
        uuid_reverse_index: Maps a port uuid to ``(host index, port index)``.

    Returns:
        A square list of lists where entry ``[i][j]`` is the number of ports
        of host ``i`` that are connected to host ``j``.
    """
    num_hosts = len(tb_hosts)
    connectivity = []
    for h in tb_hosts:
        row = [0] * num_hosts
        for p in h.ports:
            if p.connected_to is None:
                continue
            # Ports cabled to a machine outside the host list have no entry
            # in the index; ignore them instead of raising a KeyError.
            peer = uuid_reverse_index.get(p.connected_to)
            if peer is None:
                continue
            row[peer[0]] += 1
        connectivity.append(row)
    return connectivity
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
    """Print the thunderbolt topology as an undirected DOT graph on stdout.

    Each host becomes a rectangular node labelled with its ssh hostname and
    each cable an edge labelled ``<local iface>/<remote iface>``.

    Args:
        hosts: Objects exposing ``ssh_hostname``; parallel to ``tb_hosts``.
        tb_hosts: Objects exposing ``ports`` (each with ``iface`` and
            ``connected_to``).
        uuid_reverse_index: Maps a port uuid to ``(host index, port index)``.
    """
    # Make ids per node: a, b, ..., z, then two-letter ids and so on
    names = []
    for i in range(len(tb_hosts)):
        n = ""
        j = i
        while True:
            n += chr(97 + j % 26)
            j //= 26
            if j == 0:
                break
        names.append(n)

    print("graph G {")
    print(" node [shape=rectangle];")
    for i, h in enumerate(hosts):
        print(f' {names[i]} [label="{h.ssh_hostname}"];')
    for i, h in enumerate(tb_hosts):
        for p in h.ports:
            if not p.connected_to:
                continue
            # Ports cabled to a machine outside the host list cannot be drawn;
            # skip them instead of raising a KeyError.
            dst = uuid_reverse_index.get(p.connected_to)
            if dst is None:
                continue
            # Each cable is seen from both ends; emit it only from the host
            # with the smaller index.
            if dst[0] < i:
                continue
            print(f" {names[i]} -- {names[dst[0]]}", end="")
            print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]')
    print("}")
def extract_rings(connectivity):
    """Enumerate the cycles present in a connectivity matrix.

    A depth-first walk is started from every node; whenever the walk can step
    back to its start the current path is reported as a ring. Only the first
    ring found for any given set of nodes is kept.

    Args:
        connectivity: Square matrix where ``[i][j] > 0`` means node ``i`` is
            linked to node ``j``.

    Returns:
        A list of ``(nodes, count)`` tuples sorted by decreasing ring length,
        where ``count`` is the smallest link count along the ring.
    """
    num_nodes = len(connectivity)

    def walk(origin, current, trail, seen):
        trail.append(current)
        seen.add(current)
        for neighbor in range(num_nodes):
            if connectivity[current][neighbor] > 0:
                if neighbor == origin:
                    # Copy: ``trail`` keeps being mutated by the walk
                    yield list(trail)
                if neighbor not in seen:
                    yield from walk(origin, neighbor, trail, seen)
        # Backtrack
        trail.pop()
        seen.remove(current)

    found = []
    seen_node_sets = set()
    for start in range(num_nodes):
        for ring in walk(start, start, [], set()):
            key = tuple(sorted(ring))
            if key in seen_node_sets:
                continue
            seen_node_sets.add(key)
            # Pair every node with its successor, wrapping around at the end
            successors = ring[1:] + ring[:1]
            links = min(connectivity[a][b] for a, b in zip(ring, successors))
            found.append((ring, links))

    found.sort(key=lambda entry: -len(entry[0]))
    return found
def check_valid_mesh(hosts, connectivity, strict=True):
    """Check that every host is directly linked to every other host.

    Args:
        hosts: Objects exposing ``ssh_hostname`` (used for the error message).
        connectivity: Square matrix of link counts between hosts.
        strict: When true a missing link is reported through ``log_error``
            and the process exits with status 1; otherwise ``False`` is
            returned.

    Returns:
        True when the mesh is complete, False when it is not and ``strict``
        is false.
    """
    size = len(connectivity)
    ordered_pairs = ((i, j) for i in range(size) for j in range(size) if i != j)
    for i, j in ordered_pairs:
        if connectivity[i][j] > 0:
            continue
        if not strict:
            return False
        log_error(
            f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
        )
        log_error()
        log_error("Try passing --dot to visualize the connectivity")
        sys.exit(1)
    return True
def check_ssh_connections(hosts):
    """Probe all hosts in parallel for ssh access and for sudo.

    Each host is probed in its own thread, first with ``echo success`` and,
    if that works, with ``sudo ls``. Both probes use ``BatchMode=yes`` and a
    five second connect timeout. If any host cannot be reached the failing
    hosts are listed through ``log_error`` and the process exits with
    status 1.

    Args:
        hosts: Objects exposing ``ssh_hostname``.

    Returns:
        A list with one ``SSHInfo`` per host, in the order of ``hosts``.
    """
    results = [None] * len(hosts)

    def _probe(hostname, *remote_command):
        completed = run(
            [
                "ssh",
                "-o",
                "BatchMode=yes",
                "-o",
                "ConnectTimeout=5",
                hostname,
                *remote_command,
            ],
            stdout=DEVNULL,
            stderr=DEVNULL,
        )
        return completed.returncode == 0

    def _check(hostname, i):
        info = SSHInfo(False, False)
        results[i] = info

        # Check for ssh
        info.can_ssh = _probe(hostname, "echo", "success")
        if info.can_ssh:
            # Check for sudo
            info.has_sudo = _probe(hostname, "sudo", "ls")

    workers = []
    for i, h in enumerate(hosts):
        workers.append(threading.Thread(target=_check, args=(h.ssh_hostname, i)))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    if all(results):
        return results

    log_error("Could not ssh to the following hosts:")
    for h, info in zip(hosts, results):
        if not info:
            log_error(" - ", h.ssh_hostname)
    log_error()
    log_error("Maybe they are not set-up for password-less ssh?")
    sys.exit(1)
def prepare_ethernet_hostfile(args, hosts):
    """Build a hostfile that uses each host's ``en0`` address.

    The hostfile is written as JSON to ``args.output_hostfile`` when that is
    set and printed to stdout otherwise.

    Args:
        args: Parsed arguments; ``verbose`` and ``output_hostfile`` are read.
        hosts: Objects exposing ``ssh_hostname`` and ``ips``.
    """
    log(args.verbose, "Preparing an ethernet hostfile")
    add_ethernet_ips(hosts, args.verbose)

    hostfile = [{"ssh": h.ssh_hostname, "ips": h.ips} for h in hosts]

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
        return

    with open(args.output_hostfile, "w") as f:
        json.dump(hostfile, f, indent=4)
def configure_ring(args, hosts, ips, ring, sshinfo):
    """Configure the hosts for the ring backend and emit the hostfile.

    Each host is listed in ring order with the addresses of the links that
    connect it to the previous host in the ring.

    Args:
        args: Parsed arguments; ``verbose``, ``auto_setup`` and
            ``output_hostfile`` are read.
        hosts: Objects exposing ``ssh_hostname``.
        ips: An ``IPConfigurator`` holding the address assignment.
        ring: A ``(nodes, count)`` tuple as produced by ``extract_rings``.
        sshinfo: Per-host ssh/sudo information.
    """
    log(args.verbose, "Prepare a ring hostfile")
    nodes, count = ring

    hostfile = []
    for position, node in enumerate(nodes):
        host = hosts[node]
        # Index -1 wraps around, so the first node pairs with the last one
        previous = nodes[position - 1]
        entry = {
            "ssh": host.ssh_hostname,
            "ips": [ips.ips[node, previous][c][1] for c in range(count)],
            "rdma": [],
        }
        hostfile.append(entry)

    has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
        return

    with open(args.output_hostfile, "w") as f:
        json.dump(hostfile, f, indent=4)
def configure_jaccl(args, hosts, ips, sshinfo):
    """Configure the hosts for the jaccl backend and emit the hostfile.

    Every host entry lists its ethernet addresses and, for each other host,
    the name ``rdma_<iface>`` built from the first link to that host
    (``None`` on its own position).

    Args:
        args: Parsed arguments; ``verbose``, ``auto_setup`` and
            ``output_hostfile`` are read.
        hosts: Objects exposing ``ssh_hostname`` and ``ips``.
        ips: An ``IPConfigurator`` holding the address assignment.
        sshinfo: Per-host ssh/sudo information.
    """
    log(args.verbose, "Prepare a jaccl hostfile")
    check_rdma(hosts, args.verbose)
    add_ethernet_ips(hosts, args.verbose)

    num_hosts = len(hosts)
    hostfile = []
    for i, h in enumerate(hosts):
        rdma = [
            None if i == j else f"rdma_{ips.ips[i, j][0][0]}"
            for j in range(num_hosts)
        ]
        hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})

    has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
        return

    with open(args.output_hostfile, "w") as f:
        json.dump(hostfile, f, indent=4)
def prepare_tb_hostfile(args, hosts, sshinfo):
    """Discover the thunderbolt topology and configure a backend over it.

    With ``args.dot`` set the topology is printed in DOT format and nothing
    else happens. Otherwise the backend is taken from ``args.backend``; when
    that is ``None`` a ring spanning all hosts is preferred and a full mesh
    (jaccl) is used as the fallback. The process exits with status 1 when
    the requested or any usable topology cannot be found.

    Args:
        args: Parsed arguments; ``verbose``, ``dot`` and ``backend`` are read
            here, others by the functions this one calls.
        hosts: Objects exposing ``ssh_hostname``.
        sshinfo: Per-host ssh/sudo information.
    """
    log(args.verbose, "Preparing for communication over thunderbolt")
    tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)

    if args.dot:
        tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
        return

    ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
    connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)

    if args.backend is None:
        rings = extract_rings(connectivity)
        has_mesh = check_valid_mesh(hosts, connectivity, False)
        # Rings are sorted by decreasing length, so only the first can span
        # all hosts.
        has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
        if not has_ring and not has_mesh:
            log_error("Neither thunderbolt mesh nor ring found.")
            log_error("Perhaps run with --dot to generate a plot of the connectivity.")
            sys.exit(1)
        elif has_ring:
            configure_ring(args, hosts, ips, rings[0], sshinfo)
        else:
            configure_jaccl(args, hosts, ips, sshinfo)

    elif args.backend == "ring":
        rings = extract_rings(connectivity)
        has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
        if not has_ring:
            log_error("Could not find a full ring.")
            log_error()
            log_error("Try passing --dot to visualize the connectivity")
            if len(rings) > 0:
                log_error("Rings found:")
                # Every entry is a (nodes, count) tuple; only the nodes are
                # host indices.
                for nodes, _ in rings:
                    log_error(f" - {','.join(hosts[i].ssh_hostname for i in nodes)}")
            sys.exit(1)
        configure_ring(args, hosts, ips, rings[0], sshinfo)

    elif args.backend == "jaccl":
        # Strict check: exits the process when the mesh is incomplete
        check_valid_mesh(hosts, connectivity)
        configure_jaccl(args, hosts, ips, sshinfo)
def main():
    """Command line entry point: configure hosts for MLX distributed.

    Parses the command line, verifies ssh access to all hosts (exiting on
    failure) and then prepares either an ethernet hostfile or a thunderbolt
    configuration depending on ``--over``.
    """
    parser = argparse.ArgumentParser(
        description="Configure remote machines for use with MLX distributed"
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print debug messages to stderr"
    )
    parser.add_argument(
        "--hosts", default="127.0.0.1", help="A comma separated list of hosts"
    )
    parser.add_argument("--hostfile", help="The file containing the hosts")
    # NOTE(review): the option is required, so the default below is never
    # used; kept as is so the command line interface does not change.
    parser.add_argument(
        "--over",
        choices=["thunderbolt", "ethernet"],
        default="thunderbolt",
        help="What type of connectivity to configure",
        required=True,
    )
    parser.add_argument(
        "--output-hostfile", help="If provided, save the hostfile to this path"
    )
    parser.add_argument(
        "--auto-setup",
        "--no-auto-setup",
        action=OptionalBoolAction,
        nargs=0,
        dest="auto_setup",
        default=None,
        help=(
            "If set we will attempt to automatically configure the machines "
            "via ssh (requires passwordless sudo on the hosts)"
        ),
    )
    parser.add_argument(
        "--dot", action="store_true", help="Output the topology in DOT format and exit"
    )
    parser.add_argument(
        "--backend",
        choices=["ring", "jaccl"],
        default=None,
        help="Which distributed backend to configure",
    )
    args = parser.parse_args()

    if args.hostfile is not None:
        hosts = parse_hostfile(parser, args.hostfile)
    else:
        hosts = parse_hostlist(parser, args.hosts, 1)

    # Check that we can ssh
    log(
        args.verbose,
        f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
    )
    sshinfo = check_ssh_connections(hosts)

    # Prepare a hostfile for communication over ethernet using the ips of the
    # provided hostnames
    if args.over == "ethernet":
        prepare_ethernet_hostfile(args, hosts)
    # Configure the macs for communication over thunderbolt, both via RDMA and IP
    else:
        prepare_tb_hostfile(args, hosts, sshinfo)

View File

@@ -1,911 +0,0 @@
# Copyright © 2025 Apple Inc.
import argparse
import base64
import ipaddress
import json
import os
import platform
import shlex
import shutil
import sys
import tempfile
import threading
import time
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from queue import Empty as QueueEmpty
from queue import Queue
from select import select
from subprocess import PIPE, Popen, run
from typing import Optional
import mlx.core as mx
@dataclass
class Host:
    """A machine taking part in a distributed job."""

    # Numeric id of the host; parse_hostfile uses the entry's index in the file
    rank: int
    # Name or address handed to `ssh` to reach the machine
    ssh_hostname: str
    # Addresses of the host; may be empty
    ips: list[str]
@dataclass
class ThunderboltPort:
    """One thunderbolt port of a host as parsed by prepare_tb_ring."""

    # Network device name looked up as "Thunderbolt <tag>" in the hardware port map
    iface: str
    # The port's own `domain_uuid_key`
    uuid: str
    # `domain_uuid_key` of the first connected item that has one, else None
    connected_to: Optional[str]
@dataclass
class ThunderboltHost:
    """The thunderbolt ports found on one host."""

    # `device_name_key` reported by system_profiler ("" if no port was parsed)
    name: str
    # Ports of the host; prepare_tb_ring stores them sorted by `iface`
    ports: list[ThunderboltPort]
def parse_hardware_ports(ports_string):
    """Build a ``{hardware port name: device name}`` dict.

    Args:
        ports_string: UTF-8 encoded bytes in the format printed by
            ``networksetup -listallhardwareports``.

    Returns:
        A dict keyed by the text after ``Hardware Port: `` whose values are
        the text after ``Device: `` on the device line that follows.
    """
    port_prefix = "Hardware Port:"
    device_prefix = "Device:"

    devices = {}
    current = None
    text = ports_string.decode("utf-8")
    for entry in text.split("\n"):
        stripped = entry.strip()
        if entry.startswith(port_prefix):
            # Skip the prefix and the single character that follows it
            current = stripped[15:]
        elif entry.startswith(device_prefix):
            devices[current] = stripped[8:]
            current = None
    return devices
def get_num_nvidia_gpus():
    """Return the number of lines printed by ``nvidia-smi -L``.

    Blank lines are not counted, so empty output yields 0 (splitting an
    empty string on newlines would otherwise report one GPU).

    Raises:
        subprocess.CalledProcessError: If ``nvidia-smi`` exits with a
            non-zero status (``check=True``).
    """
    result = run(["nvidia-smi", "-L"], capture_output=True, text=True, check=True)
    return sum(1 for line in result.stdout.splitlines() if line.strip())
def extract_rings(hosts, index):
    """Greedily extract rings of thunderbolt links between the hosts.

    Starting from every host in turn, the longest cycle reachable over ports
    that are still unused is searched, mapped to concrete ports (which are
    then marked as used) and stored. This repeats until no more cycles are
    found from that host.

    Args:
        hosts: Objects exposing ``ports`` (each with ``connected_to``).
        index: Maps a port uuid to ``(host index, port index)``.

    Returns:
        A list of rings; each ring is a list of
        ``((src host, src port), (dst host, dst port))`` tuples.

    Raises:
        RuntimeError: If the very first cycle found cannot be mapped to
            concrete ports. A failure after at least one ring was stored
            returns the rings found so far instead.
    """

    def is_free(node, port, taken):
        if (node, port) in taken:
            return False
        return hosts[node].ports[port].connected_to is not None

    def walk(origin, node, trail, seen, taken):
        trail.append(node)
        seen.add(node)
        for port_idx, port in enumerate(hosts[node].ports):
            if is_free(node, port_idx, taken):
                neighbor, _ = index[port.connected_to]
                if neighbor == origin:
                    # Copy: ``trail`` keeps being mutated by the walk
                    yield list(trail)
                if neighbor not in seen:
                    yield from walk(origin, neighbor, trail, seen, taken)
        # Backtrack
        trail.pop()
        seen.remove(node)

    # Map the found cycle to real thunderbolt ports and reserve those ports
    # so that later cycles cannot use them again.
    def to_ports(cycle, taken):
        links = []
        successors = cycle[1:] + cycle[:1]
        for src, dst in zip(cycle, successors):
            for port_idx, port in enumerate(hosts[src].ports):
                if not is_free(src, port_idx, taken):
                    continue
                peer_node, peer_port = index[port.connected_to]
                if peer_node == dst:
                    links.append(((src, port_idx), (dst, peer_port)))
                    taken.add((src, port_idx))
                    taken.add((dst, peer_port))
                    break
            # No link was added for this step of the cycle
            if links[-1][0][0] != src:
                raise RuntimeError("Couldn't concretize the cycle")
        return links

    # Give all rings the same direction so they can be used together: pick
    # the direction in which most links go from a smaller to a larger rank.
    def orient(links):
        ascending = sum(1 for src, dst in links if src[0] < dst[0])
        descending = len(links) - ascending
        if ascending > descending:
            return links
        return [(dst, src) for src, dst in links]

    rings = []
    taken = set()
    total = len(hosts)
    for start in range(total):
        while True:
            longest = []
            for candidate in walk(start, start, [], set(), taken):
                if len(candidate) > len(longest):
                    longest = candidate
                # Break early since we won't find a bigger ring no matter what
                if len(longest) == total:
                    break
            if not longest:
                break
            try:
                rings.append(orient(to_ports(longest, taken)))
            except RuntimeError:
                if len(rings) > 0:
                    return rings
                raise
    return rings
def positive_number(x):
    """Convert ``x`` to ``int`` and require it to be strictly positive.

    Raises:
        ValueError: If ``x`` cannot be converted or the result is ``<= 0``.
    """
    value = int(x)
    if value > 0:
        return value
    raise ValueError("Number should be positive")
def log(verbose, *args, **kwargs):
    """Print a green ``[INFO]`` message when ``verbose`` is truthy.

    Positional and keyword arguments are forwarded to ``print``.
    """
    if verbose:
        print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
def log_warning(*args, **kwargs):
    """Print a yellow ``[WARN]`` message to stderr.

    Arguments are forwarded to ``print``; a ``file`` keyword supplied by the
    caller is replaced with ``sys.stderr``.
    """
    kwargs.update(file=sys.stderr)
    print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
    """Print a red ``[ERROR]`` message to stderr.

    Arguments are forwarded to ``print``; a ``file`` keyword supplied by the
    caller is replaced with ``sys.stderr``.
    """
    kwargs.update(file=sys.stderr)
    print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostfile(parser, hostfile):
    """Read the hosts from a JSON hostfile.

    The file holds a list of objects with the hostname to ssh into under
    ``"ssh"`` and, optionally, the addresses to communicate over under
    ``"ips"``, for instance::

        [
            {"ssh": "hostname1", "ips": ["123.123.123.1"]},
            {"ssh": "hostname2", "ips": ["123.123.123.2"]}
        ]

    Args:
        parser: Argument parser whose ``error`` method is used to report a
            missing or malformed file.
        hostfile (str): The path to the json file containing the host
            information.

    Returns:
        A list of ``Host`` whose rank is the position in the file.
    """
    hostfile = Path(hostfile)
    if not hostfile.exists():
        parser.error(f"Hostfile {str(hostfile)} doesn't exist")

    try:
        with open(hostfile) as f:
            entries = json.load(f)
            return [
                Host(rank, entry["ssh"], entry.get("ips", []))
                for rank, entry in enumerate(entries)
            ]
    except Exception as e:
        parser.error(f"Failed to parse hostfile {str(hostfile)} ({str(e)})")
def parse_hostlist(parser, hostlist, repeats):
    """Turn a comma separated list of hosts into ``Host`` objects.

    Every host is repeated ``repeats`` times. A host that parses as an IP
    address also has that address in its ``ips``. Ranks are the positions in
    the returned list (matching ``parse_hostfile``), and every ``Host`` gets
    its own ``ips`` list so appending to one does not affect its repeats.

    Args:
        parser: Unused; kept for signature parity with ``parse_hostfile``.
        hostlist: Comma separated hostnames or IP addresses.
        repeats: How many ``Host`` entries to create per listed host.

    Returns:
        The list of ``Host`` objects.

    Raises:
        ValueError: If the list contains an empty hostname.
    """
    hosts = []
    for h in hostlist.split(","):
        if h == "":
            raise ValueError("Hostname cannot be empty")
        try:
            ipaddress.ip_address(h)
            ips = [h]
        except ValueError:
            ips = []
        for _ in range(repeats):
            # len(hosts) is the position of the new entry, i.e. a unique rank
            hosts.append(Host(len(hosts), h, list(ips)))
    return hosts
def make_monitor_script(rank, hostfile, cwd, env, command, verbose):
    """Generate the Python script that launches ``command`` on a node.

    The generated script writes its PID to a temporary file (and prints that
    file's path), changes the working directory, builds the environment and
    finally replaces itself with ``command`` through ``os.execve``.

    Args:
        rank: Rank of the node, exported as ``MLX_RANK``.
        hostfile: JSON text of the ring hostfile, or ``""`` for none.
        cwd: Directory to change to. When ``None`` the current directory is
            tried but a failure to change into it is ignored.
        env: Iterable of ``KEY=VALUE`` strings to add to the environment.
            Values are passed through verbatim: the environment reaches the
            program through ``os.execve`` without any shell in between, so
            shell quoting would end up inside the value.
        command: The command to run, as a list of arguments.
        verbose: When truthy and a hostfile is given, export
            ``MLX_RING_VERBOSE=1``.

    Returns:
        The source code of the script as a string.
    """
    # Imports that are used throughout
    script = ""
    script += "import os\n"
    script += "import sys\n"
    script += "import tempfile\n"
    script += "from pathlib import Path\n"

    # Write the PID to a file so we can kill the process if needed
    script += "_, pidfile = tempfile.mkstemp() \n"
    script += "open(pidfile, 'w').write(str(os.getpid()))\n"
    script += "print(pidfile, flush=True)\n"

    # Change the working directory if one was requested. Otherwise attempt to
    # change to the current one but don't fail if it wasn't possible.
    d = cwd or os.getcwd()
    script += f"if Path({repr(d)}).exists():\n"
    script += f" os.chdir({repr(d)})\n"
    if cwd is not None:
        script += "else:\n"
        script += (
            f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
        )
        script += " sys.exit(1)\n"

    # Add the environment variables that were given to us
    script += "env = dict(os.environ)\n"
    for e in env:
        key, *value = e.split("=", maxsplit=1)
        # repr() below is all the escaping the generated source needs
        value = value[0] if len(value) > 0 else ""
        if not all(c.isalnum() or c == "_" for c in key):
            log_warning(f"'{e}' is an invalid environment variable so it is ignored")
            continue
        script += f"env[{repr(key)}] = {repr(value)}\n"

    # Add the environment variables to enable the ring distributed backend
    # NOTE(review): the nesting of the MLX_RANK line under this condition is
    # reconstructed from whitespace-stripped source - confirm.
    if hostfile != "":
        script += "_, hostfile = tempfile.mkstemp()\n"
        script += "with open(hostfile, 'w') as f:\n"
        script += f" f.write({repr(hostfile)})\n"
        if verbose:
            script += "env['MLX_RING_VERBOSE'] = '1'\n"
        script += "env['MLX_HOSTFILE'] = hostfile\n"
        script += f"env['MLX_RANK'] = '{rank}'\n"
        script += "\n"

    # Replace the process with the script
    script += f"command = [{','.join(map(repr, command))}]\n"
    script += "os.execve(command[0], command, env)\n"
    return script
def launch_ring(parser, hosts, args, command):
    """Launch ``command`` on every host and relay its standard streams.

    One thread per host starts the monitor script (locally or through ssh),
    forwards whatever arrives on the local stdin to the process and copies
    the process' stdout/stderr to the local ones. When a node thread finishes
    with a non-zero exit code the remaining nodes are asked to stop.

    Args:
        parser: Argument parser, used to report a missing IP configuration.
        hosts: The hosts to launch on (``ssh_hostname`` and ``ips`` are read).
        args: Parsed arguments; ``cwd``, ``env``, ``verbose``,
            ``starting_port`` and ``connections_per_ip`` are read.
        command: The command to run, as a list of arguments.
    """
    # Set by the main loop at the bottom, read by every node thread
    stop = False
    # Filled in by each node thread once its process has exited
    exit_codes = [None] * len(hosts)

    def node_thread(rank, host, hostfile, input_queue):
        is_local = host == "127.0.0.1"
        script = make_monitor_script(
            rank, hostfile, args.cwd, args.env, command, args.verbose
        )
        # The script is shipped base64 encoded inside a `python -c` command
        script_b64 = base64.b64encode(script.encode()).decode()
        cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
        if not is_local:
            cmd = f"ssh {host} '{cmd}'"
        p = Popen(
            cmd,
            shell=True,
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
        )
        os.set_blocking(p.stdout.fileno(), False)
        os.set_blocking(p.stderr.fileno(), False)
        os.set_blocking(p.stdin.fileno(), False)

        # Repeat the stdout and stderr to the local machine
        to_read = [p.stdout.fileno(), p.stderr.fileno()]
        to_write = [p.stdin.fileno(), sys.stdout.fileno(), sys.stderr.fileno()]
        pidfile = ""
        stdin_buffer = b""
        stdout_buffer = b""
        stderr_buffer = b""
        while p.poll() is None:
            try:
                stdin_buffer += input_queue.get_nowait()
            except QueueEmpty:
                pass
            rlist, wlist, _ = select(to_read, to_write, [], 1.0)
            for fd in rlist:
                msg = os.read(fd, 8192).decode(errors="ignore")

                # Fetch the PID file first if we haven't already
                # NOTE(review): the first line of the first chunk read is
                # taken as the pidfile path regardless of which stream it
                # came from.
                if pidfile == "":
                    pidfile, *msg = msg.split("\n", maxsplit=1)
                    msg = msg[0] if msg else ""

                is_stdout = fd == p.stdout.fileno()
                if is_stdout:
                    stdout_buffer += msg.encode()
                else:
                    stderr_buffer += msg.encode()
            for fd in wlist:
                # os.write may write only part of the buffer; keep the rest
                if fd == p.stdin.fileno() and len(stdin_buffer) > 0:
                    n = os.write(fd, stdin_buffer)
                    stdin_buffer = stdin_buffer[n:]
                elif fd == sys.stdout.fileno() and len(stdout_buffer) > 0:
                    n = os.write(fd, stdout_buffer)
                    stdout_buffer = stdout_buffer[n:]
                elif fd == sys.stderr.fileno() and len(stderr_buffer) > 0:
                    n = os.write(fd, stderr_buffer)
                    stderr_buffer = stderr_buffer[n:]
            if stop:
                p.terminate()
                break
        p.wait()
        exit_codes[rank] = p.returncode

        # Kill the remote program if possible
        cmd = ""
        cmd += f"pid=$(cat {pidfile}); "
        cmd += "if ps -p $pid >/dev/null; then "
        cmd += " kill $pid; "
        cmd += " echo 1; "
        cmd += "else "
        cmd += " echo 0; "
        cmd += "fi; "
        cmd += f"rm {pidfile}"
        if not is_local:
            cmd = f"ssh {host} '{cmd}'"
        c = run(cmd, check=True, shell=True, capture_output=True, text=True)
        # "1" is echoed by the command above when the process had to be killed
        if c.stdout.strip() == "1":
            log_warning(f"Node with rank {rank} was killed")
        elif p.returncode != 0:
            log_warning(f"Node with rank {rank} exited with code {p.returncode}")
        else:
            log(args.verbose, f"Node with rank {rank} completed")

    if all(len(h.ips) == 0 for h in hosts):
        parser.error(
            "The ring backend requires IPs to be provided instead of hostnames"
        )

    # Give every connection of every host its own "ip:port" entry; the port
    # counter is shared across hosts.
    port = args.starting_port
    ring_hosts = []
    for h in hosts:
        node = []
        for ip in h.ips:
            for i in range(args.connections_per_ip):
                node.append(f"{ip}:{port}")
                port += 1
        ring_hosts.append(node)
    # No hostfile is passed on when there is a single host
    hostfile = json.dumps(ring_hosts) if len(ring_hosts) > 1 else ""

    log(args.verbose, "Running", shlex.join(command))

    input_queues = []
    threads = []
    for i, h in enumerate(hosts):
        # The last host is started one second after the others
        if i + 1 == len(hosts):
            time.sleep(1.0)
        input_queues.append(Queue())
        t = threading.Thread(
            target=node_thread, args=(i, h.ssh_hostname, hostfile, input_queues[-1])
        )
        t.start()
        threads.append(t)

    # Broadcast the local stdin to all nodes until every thread has finished
    # or one of them has failed.
    os.set_blocking(sys.stdin.fileno(), False)
    while not stop:
        rlist, _, _ = select([sys.stdin.fileno()], [], [], 1.0)
        for fd in rlist:
            stdin_buffer = os.read(fd, 8192)
            for q in input_queues:
                q.put(stdin_buffer)
        if any(t.is_alive() for t in threads):
            for i, t in enumerate(threads):
                if not t.is_alive():
                    if exit_codes[i] != 0:
                        stop = True
                        break
        else:
            break
    for t in threads:
        t.join()
def get_mpi_libname():
    """Best-effort lookup of the MPI shared library name.

    Locates ``ompi_info`` with ``which`` and inspects the libraries it links
    against (``otool -L`` on macOS, ``ldd`` elsewhere) for the first one whose
    line contains ``libmpi``.

    Returns:
        The first token of that line with a leading ``@rpath/`` removed, or
        ``None`` when any step of the lookup fails.
    """
    try:
        ompi_info = run(["which", "ompi_info"], check=True, capture_output=True)
        ompi_info = ompi_info.stdout.strip().decode()

        if platform.system() == "Darwin":
            otool_output = run(
                ["otool", "-L", ompi_info], check=True, capture_output=True
            )
        else:
            otool_output = run(["ldd", ompi_info], check=True, capture_output=True)
        otool_output = otool_output.stdout.decode()

        # StopIteration if not found
        libmpi_line = next(
            filter(lambda line: "libmpi" in line, otool_output.splitlines())
        )
        return libmpi_line.strip().split()[0].removeprefix("@rpath/")
    except Exception:
        # Deliberately broad: any failure means "unknown". KeyboardInterrupt
        # and SystemExit are not swallowed any more.
        return None
def launch_mpi(parser, hosts, args, command):
    """Run ``command`` on the hosts through ``mpirun``.

    A temporary MPI hostfile with one ``<host> slots=<n>`` line per distinct
    hostname is written and handed to ``mpirun`` together with the working
    directory, the environment variables and any extra MPI arguments from
    ``args``. A keyboard interrupt during the run is swallowed.

    Args:
        parser: Unused.
        hosts: Objects exposing ``ssh_hostname``.
        args: Parsed arguments; ``env`` (also extended in place), ``verbose``,
            ``cwd`` and ``mpi_arg`` are read.
        command: The command to run, as a list of arguments.
    """
    located = run(["which", "mpirun"], check=True, capture_output=True)
    mpirun = located.stdout.strip().decode()

    # Compatibility with homebrew and pip installs
    mpi_libname = get_mpi_libname()
    if mpi_libname is not None:
        dyld = Path(mpirun).parent.parent / "lib"
        extra_env = [
            f"DYLD_LIBRARY_PATH={str(dyld)}",
            f"MLX_MPI_LIBNAME={mpi_libname}",
        ]
        args.env = extra_env + args.env

    log(args.verbose, f"Using '{mpirun}'")
    with tempfile.NamedTemporaryFile(mode="w") as f:
        slots_per_host = Counter(h.ssh_hostname for h in hosts)
        for hostname, slots in slots_per_host.items():
            print(f"{hostname} slots={slots}", file=f)
        # Make sure mpirun sees the content
        f.flush()

        cwd_args = ["-cwd", args.cwd] if args.cwd else []
        env_args = [token for e in args.env for token in ("-x", e)]
        mpi_args = [token for arg in args.mpi_arg for token in shlex.split(arg)]
        cmd = [
            mpirun,
            "--output",
            ":raw",  # do not line buffer output
            "--hostfile",
            f.name,
            *cwd_args,
            *env_args,
            *mpi_args,
            "--",
            *command,
        ]
        log(args.verbose, "Running", " ".join(cmd))
        try:
            run(cmd)
        except KeyboardInterrupt:
            pass
def launch_nccl(parser, hosts, args, command):
    """Start one local process per host entry for the NCCL backend.

    Every process gets ``MLX_RANK`` and ``CUDA_VISIBLE_DEVICES`` set to
    ``rank % args.repeat_hosts``. If a process exits with a non-zero status,
    or on a keyboard interrupt, the processes still running are killed and
    the exception is re-raised.

    Args:
        parser: Unused.
        hosts: Objects exposing ``ips``; the first address of the first host
            must be ``127.0.0.1``.
        args: Parsed arguments; ``nccl_port``, ``verbose`` and
            ``repeat_hosts`` are read.
        command: The command to run, as a list of arguments.

    Raises:
        ValueError: If the first host is not localhost.
        RuntimeError: If there are no GPUs, fewer GPUs than
            ``args.repeat_hosts``, or a process exits with a non-zero status.
    """
    master_host = hosts[0].ips[0]
    if master_host != "127.0.0.1":
        raise ValueError("The NCCL backend only supports localhost for now.")
    master_port = args.nccl_port
    world_size = len(hosts)

    base_env = os.environ.copy()
    # NOTE(review): the non-verbose default is "DEBUG" and the verbose one
    # "INFO" - reproduced as found, confirm this is intended.
    default_debug = "INFO" if args.verbose else "DEBUG"
    base_env["NCCL_DEBUG"] = base_env.get("NCCL_DEBUG", default_debug)
    base_env["NCCL_SOCKET_IFNAME"] = "lo"  # Use loopback for local communication
    base_env["NCCL_HOST_IP"] = master_host
    base_env["NCCL_PORT"] = str(master_port)
    base_env["MLX_WORLD_SIZE"] = str(world_size)

    procs = []
    num_gpus = get_num_nvidia_gpus()
    if num_gpus == 0:
        raise RuntimeError("Cannot run NCCL backend with no GPUs.")
    if args.repeat_hosts > num_gpus:
        raise RuntimeError("NCCL requires a separate GPU per process.")

    try:
        for rank in range(world_size):
            mlx_rank = str(rank % args.repeat_hosts)
            env = base_env.copy()
            env["MLX_RANK"] = mlx_rank
            env["CUDA_VISIBLE_DEVICES"] = mlx_rank
            procs.append(Popen(command, env=env))

        for proc in procs:
            ret = proc.wait()
            if ret != 0:
                raise RuntimeError(f"Rank process exited with {ret}")
    except (RuntimeError, KeyboardInterrupt):
        # Best-effort cleanup of whatever is still running
        for proc in procs:
            if proc.poll() is not None:
                continue
            try:
                proc.kill()
            except Exception:
                pass
        raise
def check_ssh_connections(hosts):
    """Verify in parallel that every host accepts a non-interactive ssh.

    Each host is probed in its own thread with ``echo success`` using
    ``BatchMode=yes`` and a five second connect timeout. If any probe fails
    the failing hosts are listed through ``log_error`` and the process exits
    with status 1.

    Args:
        hosts: Objects exposing ``ssh_hostname``.
    """
    reachable = [False] * len(hosts)

    def _check(hostname, i):
        probe = [
            "ssh",
            "-o",
            "BatchMode=yes",
            "-o",
            "ConnectTimeout=5",
            hostname,
            "echo",
            "success",
        ]
        completed = run(probe, stdout=PIPE, stderr=PIPE)
        reachable[i] = completed.returncode == 0

    workers = []
    for i, h in enumerate(hosts):
        workers.append(threading.Thread(target=_check, args=(h.ssh_hostname, i)))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    if all(reachable):
        return

    log_error("Could not ssh to the following hosts:")
    for h, ok in zip(hosts, reachable):
        if not ok:
            log_error(" - ", h.ssh_hostname)
    log_error()
    log_error("Maybe they are not set-up for password-less ssh?")
    sys.exit(1)
def prepare_tb_ring(args, hosts):
    """Discover thunderbolt rings between the hosts and configure them.

    The connectivity of every host is queried through ssh, rings are
    extracted with ``extract_rings``, every link gets a pair of addresses in
    ``192.168.X.Y`` and a hostfile ordered along the first ring is produced.
    Unless ``args.hostfile_only`` is set, the interface configuration commands
    are either run on the hosts (``args.auto_setup``) or printed for manual
    execution. With ``args.dot`` only a DOT graph of the rings is printed.

    Args:
        args: Parsed arguments; ``verbose``, ``auto_setup``, ``dot``,
            ``hostfile_only`` and ``output_hostfile`` are read.
        hosts: Objects exposing ``ssh_hostname``.

    Raises:
        ValueError: If the local address range is exhausted.
    """
    log(
        args.verbose,
        f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
    )

    # Check that we can ssh
    check_ssh_connections(hosts)
    if args.auto_setup and args.verbose:
        log_warning(
            "--auto-setup is requested which requires password-less sudo",
            "on the remote hosts",
        )

    # Extract the current connectivity from the remote hosts
    thunderbolt_connections = []
    for h in hosts:
        log(args.verbose, "Getting connectivity from", h.ssh_hostname)
        thunderbolt_connections.append(
            json.loads(
                run(
                    [
                        "ssh",
                        h.ssh_hostname,
                        "system_profiler",
                        "SPThunderboltDataType",
                        "-json",
                    ],
                    capture_output=True,
                ).stdout
            )
        )
    interface_maps = []
    for h in hosts:
        log(args.verbose, "Getting interface names from", h.ssh_hostname)
        interface_maps.append(
            parse_hardware_ports(
                run(
                    [
                        "ssh",
                        h.ssh_hostname,
                        "networksetup",
                        "-listallhardwareports",
                    ],
                    capture_output=True,
                ).stdout
            )
        )

    # Parse the connectivity into some simple dataclasses
    tb_hosts = []
    for c, iface_map in zip(thunderbolt_connections, interface_maps):
        name = ""
        ports = []
        for t in c["SPThunderboltDataType"]:
            uuid = t.get("domain_uuid_key")
            if uuid is None:
                continue
            name = t["device_name_key"]
            tag = t["receptacle_1_tag"]["receptacle_id_key"]
            items = t.get("_items", [])
            # Only connected items that carry their own uuid count as peers
            connected_items = [item for item in items if "domain_uuid_key" in item]
            connected_to = (
                connected_items[0]["domain_uuid_key"] if connected_items else None
            )
            iface = iface_map[f"Thunderbolt {tag}"]
            ports.append(ThunderboltPort(iface, uuid, connected_to))
        tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))

    # Create a reverse index to be able to map uuids to (host, port) quickly
    uuid_reverse_index = {}
    for i, h in enumerate(tb_hosts):
        for j, p in enumerate(h.ports):
            uuid_reverse_index[p.uuid] = (i, j)

    # Find the rings by simply walking and marking visited (host, port) tuples
    # and keeping the largest rings greedily.
    log(args.verbose, "Extracting rings from the parsed connectivity")
    rings = extract_rings(tb_hosts, uuid_reverse_index)

    # Just output a DOT graphical representation of the found rings
    if args.dot:
        # Node ids: a, b, ..., z, then two-letter ids and so on
        names = []
        for i in range(len(tb_hosts)):
            n = ""
            j = i
            while True:
                n += chr(97 + j % 26)
                j //= 26
                if j == 0:
                    break
            names.append(n)
        print("graph G {")
        print(" node [shape=rectangle];")
        for i, h in enumerate(hosts):
            print(f' {names[i]} [label="{h.ssh_hostname}"];')
        for r in rings:
            for (i, _), (j, _) in r:
                print(f" {names[i]} -- {names[j]};")
        print("}")
        return

    # Assign IPs to each interface such that the interfaces can communicate
    # `ips` and `pairs` are keyed by (host, port); `expecting` collects the
    # destination end of every link.
    ips = {}
    pairs = {}
    expecting = set()
    ip0 = 0
    ip1 = 0
    netmask = "255.255.255.252"
    for r in rings:
        for a, b in r:
            ips[a] = f"192.168.{ip0}.{ip1 + 1}"
            ips[b] = f"192.168.{ip0}.{ip1 + 2}"
            pairs[a] = b
            pairs[b] = a
            expecting.add(b)
            # Every link gets its own block of four addresses
            ip1 += 4
            if ip1 > 255:
                ip0 += 1
                ip1 = 0
            if ip0 > 255:
                raise ValueError("Ran out of available local IPs for the ring")

    # Extract the host order from the first ring
    # NOTE(review): indexing rings[0] fails with IndexError when no ring was
    # found - confirm callers can rely on at least one ring.
    hostmap = dict((r[0][0], r[1][0]) for r in rings[0])
    first_host = min(hostmap.keys())
    order = [first_host]
    while hostmap[order[-1]] != first_host:
        order.append(hostmap[order[-1]])

    # Create the hostfile
    hostfile = []
    for i in order:
        h = hosts[i]
        host = {
            "ssh": h.ssh_hostname,
            "ips": [
                ips[i, j]
                for j, p in enumerate(tb_hosts[i].ports)
                if (i, j) in expecting
            ],
        }
        hostfile.append(host)

    if not args.hostfile_only:
        for i, h in enumerate(hosts):
            command = ""
            command += "sudo ifconfig bridge0 down\n"
            for j, p in enumerate(tb_hosts[i].ports):
                # Ports that are not part of any ring have no address
                if (i, j) not in ips:
                    continue
                iface = p.iface
                ip = ips[i, j]
                peer = ips[pairs[i, j]]
                command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
                command += f"sudo route change {peer} -interface {iface}\n"
            if args.auto_setup:
                print(f"Running auto setup for {h.ssh_hostname}")
                command = command.strip().replace("\n", " && ")
                command = ["ssh", h.ssh_hostname, command]
                log(args.verbose, shlex.join(command))
                run(command)
            else:
                msg = f"Setup for {h.ssh_hostname}"
                print(msg)
                print("=" * len(msg))
                print(command)
                input("Enter to continue")
                print()

    if args.output_hostfile:
        with open(args.output_hostfile, "w") as f:
            json.dump(hostfile, f, indent=4)
    else:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
def prepare_hostfile(args, hosts):
    """Build an ethernet hostfile for ``hosts`` and write or print it.

    Each host is queried over ssh with ``ipconfig getifaddr en0`` and the
    reported address is appended to the host's ``ips`` list. The resulting
    list of ``{"ssh": ..., "ips": ...}`` entries is saved as JSON to
    ``args.output_hostfile`` when that is set, otherwise it is printed to
    stdout.

    Args:
        args: Parsed command line arguments; ``verbose`` and
            ``output_hostfile`` are read.
        hosts: Host objects exposing ``ssh_hostname`` and a mutable ``ips``
            list. The ``ips`` lists are modified in place.
    """
    log(
        args.verbose,
        f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
    )

    # Check that we can ssh
    check_ssh_connections(hosts)

    # Get the ips for each host
    for h in hosts:
        log(args.verbose, "Getting the ip from", h.ssh_hostname)
        ip = run(
            ["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
            capture_output=True,
            text=True,
        ).stdout.strip()
        if ip:
            h.ips.append(ip)
        else:
            # An empty answer means the lookup failed (ssh error, no address
            # on en0, ...). Recording "" as an IP would produce a hostfile
            # that looks valid but cannot be used, so skip it and say so.
            print(
                f"Could not get the ip of en0 from {h.ssh_hostname}",
                file=sys.stderr,
            )

    hostfile = [dict(ssh=h.ssh_hostname, ips=h.ips) for h in hosts]

    if args.output_hostfile:
        with open(args.output_hostfile, "w", encoding="utf-8") as f:
            json.dump(hostfile, f, indent=4)
    else:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
def distributed_config():
    """Command line entry point that configures hosts for MLX distributed.

    Parses the command line, resolves the list of hosts either from
    ``--hostfile`` or from the comma separated ``--hosts`` list and then
    dispatches to ``prepare_tb_ring`` (thunderbolt) or ``prepare_hostfile``
    (ethernet).

    Raises:
        ValueError: If MPI over thunderbolt is requested.
    """
    parser = argparse.ArgumentParser(
        description="Configure remote machines for use with MLX distributed"
    )

    # Declared as a table so that the set of options reads at a glance; the
    # order is the order in which they appear in ``--help``.
    options = [
        (
            "--verbose",
            dict(action="store_true", help="Print debug messages in stdout"),
        ),
        (
            "--backend",
            dict(
                choices=["ring", "mpi", "nccl"],
                default="nccl" if mx.cuda.is_available() else "ring",
                help="Which distributed backend to configure",
            ),
        ),
        (
            "--over",
            dict(
                choices=["thunderbolt", "ethernet"],
                default="thunderbolt",
                help="What type of connectivity to configure",
            ),
        ),
        (
            "--hosts",
            dict(default="127.0.0.1", help="A comma separated list of hosts"),
        ),
        ("--hostfile", dict(help="The file containing the hosts")),
        (
            "--dot",
            dict(
                action="store_true",
                help="Output the topology in DOT format and exit",
            ),
        ),
        (
            "--hostfile-only",
            dict(action="store_true", help="If set only compute the hostfile"),
        ),
        (
            "--output-hostfile",
            dict(help="If provided, save the hostfile to this path"),
        ),
        (
            "--auto-setup",
            dict(
                action="store_true",
                help="If set we will attempt to automatically configure the machines via ssh",
            ),
        ),
    ]
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    unsupported = args.backend == "mpi" and args.over == "thunderbolt"
    if unsupported:
        message = (
            "The configuration of MPI over thunderbolt is "
            "not supported yet by mlx.distributed_config"
        )
        raise ValueError(message)

    # A hostfile takes precedence over the --hosts list
    if args.hostfile is None:
        hosts = parse_hostlist(parser, args.hosts, 1)
    else:
        hosts = parse_hostfile(parser, args.hostfile)

    if args.over != "thunderbolt":
        prepare_hostfile(args, hosts)
    else:
        prepare_tb_ring(args, hosts)
def main():
    """Command line entry point that launches an MLX distributed program.

    Known options are parsed and everything else is treated as the command
    to launch. If the first remaining argument is an existing file it is run
    with the current python executable, otherwise it is looked up on the
    ``PATH``. The command is then handed to the launcher of the selected
    backend.

    Raises:
        ValueError: If the first remaining argument is neither an existing
            file nor a command that can be found on the ``PATH``.
    """
    parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
    parser.add_argument(
        "--print-python",
        action="store_true",
        help="Print the path to the current python executable and exit",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print debug messages in stdout"
    )
    parser.add_argument(
        "--hosts", default="127.0.0.1", help="A comma separated list of hosts"
    )
    parser.add_argument(
        "--repeat-hosts",
        "-n",
        type=positive_number,
        default=1,
        help="Repeat each host a given number of times",
    )
    parser.add_argument("--hostfile", help="The file containing the hosts")
    parser.add_argument(
        "--backend",
        choices=["ring", "mpi", "nccl", "jaccl"],
        default="nccl" if mx.cuda.is_available() else "ring",
        help="Which distributed backend to launch",
    )
    parser.add_argument(
        "--env",
        action="append",
        default=[],
        help="Set environment variables for the jobs",
    )
    parser.add_argument(
        "--mpi-arg",
        action="append",
        default=[],
        help="Arguments to pass directly to mpirun",
    )
    parser.add_argument(
        "--connections-per-ip",
        default=1,
        type=int,
        help="How many connections per ip to use for the ring backend",
    )
    parser.add_argument(
        "--starting-port",
        "-p",
        type=int,
        default=5000,
        help="For the ring backend listen on this port increasing by 1 per rank and IP",
    )
    parser.add_argument(
        "--cwd", help="Set the working directory on each node to the provided one"
    )
    parser.add_argument(
        "--nccl-port",
        type=int,
        default=12345,
        help="The port to use for the NCCL communication (only for nccl backend)",
    )

    args, rest = parser.parse_known_args()

    if args.print_python:
        print(sys.executable)
        return

    # Drop the "--" separator before checking for a script, otherwise a
    # command line that ends in "--" passes the check and then fails with an
    # IndexError when the script is looked up.
    if rest and rest[0] == "--":
        rest.pop(0)
    if not rest:
        parser.error("No script is provided")

    # Try to extract a list of hosts and corresponding ips
    if args.hostfile is not None:
        hosts = parse_hostfile(parser, args.hostfile)
    else:
        hosts = parse_hostlist(parser, args.hosts, args.repeat_hosts)

    # Check if the script is a file and convert it to a full path. A
    # directory that happens to share the name must not be run with python.
    if (script := Path(rest[0])).exists() and script.is_file():
        rest[0:1] = [sys.executable, str(script.resolve())]
    elif (command := shutil.which(rest[0])) is not None:
        rest[0] = command
    else:
        raise ValueError(f"Invalid script or command {rest[0]}")

    # Launch; the backends are mutually exclusive
    if args.backend == "ring":
        launch_ring(parser, hosts, args, rest)
    elif args.backend == "mpi":
        launch_mpi(parser, hosts, args, rest)
    elif args.backend == "nccl":
        launch_nccl(parser, hosts, args, rest)
    elif args.backend == "jaccl":
        launch_jaccl(parser, hosts, args, rest)
# Run the launcher only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()

View File

@@ -45,13 +45,11 @@ class CommandProcess:
class RemoteProcess(CommandProcess):
def __init__(self, rank, host, cwd, files, env, command):
def __init__(self, rank, host, python, cwd, files, env, command):
is_local = host == "127.0.0.1"
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
cmd = RemoteProcess.make_launch_script(rank, cwd, files, env, command)
if not is_local:
cmd = f"ssh {host} '{cmd}'"
cmd = f"ssh {host} {shlex.quote(cmd)}"
self._host = host
self._pidfile = None
@@ -90,47 +88,33 @@ class RemoteProcess(CommandProcess):
self._process.wait()
# Kill the remote program if possible
cmd = ""
cmd += f"pid=$(cat {self._pidfile}); "
cmd += "if ps -p $pid >/dev/null; then "
cmd += " kill $pid; "
cmd += " echo 1; "
cmd += "else "
cmd += " echo 0; "
cmd += "fi; "
cmd += f"rm {self._pidfile}"
cmd = RemoteProcess.make_kill_script(self._pidfile)
if not self._is_local:
cmd = f"ssh {self._host} '{cmd}'"
cmd = f"ssh {self._host} {shlex.quote(cmd)}"
c = run(cmd, check=True, shell=True, capture_output=True, text=True)
self._killed = c.stdout.strip() == "1"
@staticmethod
def make_monitor_script(rank, cwd, files, env, command):
# Imports that are used throughout
def make_launch_script(rank, cwd, files, env, command):
script = ""
script += "import os\n"
script += "import sys\n"
script += "import tempfile\n"
script += "from pathlib import Path\n"
# Write the PID to a file so we can kill the process if needed
script += "_, pidfile = tempfile.mkstemp() \n"
script += "open(pidfile, 'w').write(str(os.getpid()))\n"
script += "print(pidfile, flush=True)\n"
script += "pidfile=$(mktemp); "
script += "echo $$ > $pidfile; "
script += "echo $pidfile; "
# Change the working directory if one was requested. Otherwise attempt to
# change to the current one but don't fail if it wasn't possible.
d = cwd or os.getcwd()
script += f"if Path({repr(d)}).exists():\n"
script += f" os.chdir({repr(d)})\n"
script += f"if [[ -d {repr(d)} ]]; then "
script += f" cd {repr(d)}; "
if cwd is not None:
script += "else:\n"
script += f" print('Failed to change directory to', {repr(d)}, file=sys.stderr)\n"
script += f" sys.exit(1)\n"
script += "else "
script += f" echo 'Failed to change directory to' {repr(d)} >2; "
script += "fi; "
# Add the environment variables that were requested
script += "env = dict(os.environ)\n"
for e in env:
key, *value = e.split("=", maxsplit=1)
value = shlex.quote(value[0]) if len(value) > 0 else ""
@@ -139,22 +123,34 @@ class RemoteProcess(CommandProcess):
f"'{e}' is an invalid environment variable so it is ignored"
)
continue
script += f"env[{repr(key)}] = {repr(value)}\n"
script += f"export {key}={value}; "
# Make the temporary files
for env_name, content in files.items():
script += "_, fname = tempfile.mkstemp()\n"
script += "with open(fname, 'w') as f:\n"
script += f" f.write({repr(content)})\n"
script += f"env[{repr(env_name)}] = fname\n"
script += "fname=$(mktemp); "
script += f"echo {shlex.quote(content)} >$fname; "
script += f"export {env_name}=$fname; "
# Finally add the rank
script += f"env['MLX_RANK'] = '{rank}'\n"
script += "\n"
script += f"export MLX_RANK={rank}; "
# Replace the process with the script
script += f"command = [{','.join(map(repr, command))}]\n"
script += "os.execve(command[0], command, env)\n"
script += f"cmd=({' '.join(map(shlex.quote, command))}); "
script += 'exec "${cmd[@]}"'
return script
@staticmethod
def make_kill_script(pidfile):
    """Return a shell snippet that kills the process recorded in ``pidfile``.

    The snippet reads the pid from ``pidfile`` and, if ``ps -p`` finds that
    process, sends it ``kill`` and echoes ``1``; otherwise it echoes ``0``.
    The pid file is removed at the end in both cases.
    """
    fragments = (
        f"pid=$(cat {pidfile}); ",
        "if ps -p $pid >/dev/null; then ",
        " kill $pid; ",
        " echo 1; ",
        "else ",
        " echo 0; ",
        "fi; ",
        f"rm {pidfile}",
    )
    return "".join(fragments)
@@ -309,7 +305,7 @@ def launch_ring(parser, hosts, args, command):
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
@@ -341,6 +337,7 @@ def launch_nccl(parser, hosts, args, command):
(
rank,
h.ssh_hostname,
args.python,
cwd,
{},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
@@ -374,7 +371,7 @@ def launch_jaccl(parser, hosts, args, command):
_launch_with_io(
RemoteProcess,
[
((rank, h.ssh_hostname, cwd, files, env, command), {})
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts)
],
args.verbose,
@@ -503,11 +500,20 @@ def main():
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
parser.add_argument(
"--no-verify-script",
action="store_false",
dest="verify_script",
help="Do not verify that the script exists",
)
parser.add_argument(
"--python", default=sys.executable, help="Use this python on the remote hosts"
)
args, rest = parser.parse_known_args()
if args.print_python:
print(sys.executable)
print(args.python)
return
if len(rest) == 0:
@@ -523,10 +529,10 @@ def main():
# Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [sys.executable, str(script.resolve())]
rest[0:1] = [args.python, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None:
rest[0] = command
else:
elif args.verify_script:
raise ValueError(f"Invalid script or command {rest[0]}")
# Launch

View File

@@ -266,7 +266,7 @@ if __name__ == "__main__":
entry_points = {
"console_scripts": [
"mlx.launch = mlx._distributed_utils.launch:main",
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
"mlx.distributed_config = mlx._distributed_utils.config:main",
]
}
install_requires = []