Refactor distributed config

This commit is contained in:
Angelos Katharopoulos
2025-12-09 05:58:44 -08:00
parent cd4b12ce1b
commit 405d30b6e5
4 changed files with 556 additions and 9 deletions

View File

@@ -1,5 +1,6 @@
# Copyright © 2025 Apple Inc. # Copyright © 2025 Apple Inc.
import argparse
import ipaddress import ipaddress
import json import json
import sys import sys
@@ -16,6 +17,14 @@ class Host:
rdma: list[Optional[str]] rdma: list[Optional[str]]
class OptionalBoolAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if option_string.startswith("--no-"):
setattr(namespace, self.dest, False)
else:
setattr(namespace, self.dest, True)
def positive_number(x): def positive_number(x):
x = int(x) x = int(x)
if x <= 0: if x <= 0:
@@ -50,7 +59,7 @@ def parse_hostlist(parser, hostlist, repeats):
except ValueError: except ValueError:
ips = [] ips = []
for i in range(repeats): for i in range(repeats):
hosts.append(Host(i, h, ips)) hosts.append(Host(i, h, ips, []))
return hosts return hosts

View File

@@ -0,0 +1,526 @@
# Copyright © 2025 Apple Inc.
import argparse
import json
import shlex
import sys
import threading
from collections import defaultdict
from dataclasses import dataclass
from subprocess import DEVNULL, run
from typing import Optional
import mlx.core as mx
from .common import (
Host,
OptionalBoolAction,
log,
log_error,
parse_hostfile,
parse_hostlist,
)
@dataclass
class SSHInfo:
can_ssh: bool
has_sudo: bool
def __bool__(self):
return self.can_ssh
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def add_ethernet_ips(hosts, verbose=False):
# Get the ips for each host
for h in hosts:
log(verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
class IPConfigurator:
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
assigned = set()
ips = defaultdict(list)
ip0 = 0
ip1 = 0
for src_node, h in enumerate(tb_hosts):
for src_port, p in enumerate(h.ports):
if not p.connected_to:
continue
if (src_node, src_port) in assigned:
continue
dst_node, dst_port = uuid_reverse_index[p.connected_to]
ip_src = f"192.168.{ip0}.{ip1 + 1}"
ip_dst = f"192.168.{ip0}.{ip1 + 2}"
iface_src = p.iface
iface_dst = tb_hosts[dst_node].ports[dst_port].iface
ips[src_node, dst_node].append((iface_src, ip_src))
ips[dst_node, src_node].append((iface_dst, ip_dst))
assigned.add((src_node, src_port))
assigned.add((dst_node, dst_port))
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs")
self.ips = ips
self.hosts = hosts
self.tb_hosts = tb_hosts
def setup(self, verbose=False, auto_setup=False):
netmask = "255.255.255.252"
for i, (h, th) in enumerate(zip(self.hosts, self.tb_hosts)):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j in range(len(self.hosts)):
if i == j or (i, j) not in self.ips:
continue
for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " ; ")
command = ["ssh", h.ssh_hostname, command]
log(verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def extract_connectivity(hosts, verbose):
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
uuid = t.get("domain_uuid_key")
if uuid is None:
continue
name = t["device_name_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
items = t.get("_items", [])
connected_items = [item for item in items if "domain_uuid_key" in item]
connected_to = (
connected_items[0]["domain_uuid_key"] if connected_items else None
)
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
return tb_hosts, uuid_reverse_index
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
connectivity = []
for i, h in enumerate(tb_hosts):
c = [0] * len(tb_hosts)
for p in h.ports:
if p.connected_to is not None:
j, _ = uuid_reverse_index[p.connected_to]
c[j] += 1
connectivity.append(c)
return connectivity
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
# Make ids per node
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for i, h in enumerate(tb_hosts):
for p in h.ports:
if not p.connected_to:
continue
dst = uuid_reverse_index[p.connected_to]
if dst[0] < i:
continue
print(f" {names[i]} -- {names[dst[0]]}", end="")
print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]')
print("}")
def extract_rings(connectivity):
rings = []
existing_rings = set()
num_nodes = len(connectivity)
def dfs(start_node, node, path, visited):
path.append(node)
visited.add(node)
for j in range(num_nodes):
if connectivity[node][j] <= 0:
continue
if j == start_node:
yield path[:]
if j not in visited:
yield from dfs(start_node, j, path, visited)
path.pop()
visited.remove(node)
for start in range(num_nodes):
for r in dfs(start, start, [], set()):
cnt = min(connectivity[r[i]][r[(i + 1) % len(r)]] for i in range(len(r)))
rkey = tuple(sorted(r))
if rkey not in existing_rings:
rings.append((r, cnt))
existing_rings.add(rkey)
return sorted(rings, key=lambda x: -len(x[0]))
def check_valid_mesh(hosts, connectivity, strict=True):
num_nodes = len(connectivity)
for i in range(num_nodes):
for j in range(num_nodes):
if i == j:
continue
if connectivity[i][j] <= 0:
if strict:
log_error(
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
)
sys.exit(1)
else:
return False
return True
def check_ssh_connections(hosts):
results = [None] * len(hosts)
def _check(hostname, i):
info = SSHInfo(False, False)
results[i] = info
# Check for ssh
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.can_ssh = result.returncode == 0
if not info.can_ssh:
return
# Check for sudo
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"sudo",
"ls",
],
stdout=DEVNULL,
stderr=DEVNULL,
)
info.has_sudo = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
return results
def prepare_ethernet_hostfile(args, hosts):
log(args.verbose, f"Preparing an ethernet hostfile")
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_ring(args, hosts, ips, ring):
log(args.verbose, "Prepare a ring hostfile")
ring, count = ring
hostfile = []
for i, node in enumerate(ring):
h = hosts[node]
peer = ring[i - 1]
hostfile.append(
{
"ssh": h.ssh_hostname,
"ips": [ips.ips[node, peer][c][1] for c in range(count)],
"rdma": [],
}
)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def configure_jaccl(args, hosts, ips):
log(args.verbose, "Prepare a jaccl hostfile")
add_ethernet_ips(hosts)
hostfile = []
for i, h in enumerate(hosts):
rdma = []
for j in range(len(hosts)):
if i == j:
rdma.append(None)
else:
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_tb_hostfile(args, hosts):
log(args.verbose, f"Preparing for communication over thunderbolt")
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
if args.dot:
tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
return
ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)
if args.backend is None:
rings = extract_rings(connectivity)
has_mesh = check_valid_mesh(hosts, connectivity, False)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring and not has_mesh:
log_error("Neither thunderbolt mesh nor ring found.")
log_error("Perhaps run with --dot to generate a plot of the connectivity.")
sys.exit(1)
elif has_ring:
configure_ring(args, hosts, ips, rings[0])
else:
configure_jaccl(args, hosts, ips)
elif args.backend == "ring":
rings = extract_rings(connectivity)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring:
log_error("Could not find a full ring.")
if len(rings) > 0:
log_error("Rings found:")
for r in rings:
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
sys.exit(1)
configure_ring(args, hosts, ips, rings[0])
elif args.backend == "jaccl":
check_valid_mesh(hosts, connectivity)
configure_jaccl(args, hosts, ips)
def main():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
required=True,
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
"--no-auto-setup",
action=OptionalBoolAction,
nargs=0,
dest="auto_setup",
default=None,
)
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--backend",
choices=["ring", "jaccl"],
default=None,
help="Which distributed backend to configure",
)
# parser.add_argument(
# "--hostfile-only", action="store_true", help="If set only compute the hostfile"
# )
args = parser.parse_args()
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
# Check that we can ssh
log(
args.verbose,
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
)
check_ssh_connections(hosts)
# Prepare a hostfile for communication over ethernet using the ips of the
# provided hostnames
if args.over == "ethernet":
prepare_ethernet_hostfile(args, hosts)
# Configure the macs for communication over thunderbolt, both via RDMA and IP
else:
prepare_tb_hostfile(args, hosts)

View File

@@ -45,11 +45,13 @@ class CommandProcess:
class RemoteProcess(CommandProcess): class RemoteProcess(CommandProcess):
def __init__(self, rank, host, cwd, files, env, command): def __init__(self, rank, host, python, cwd, files, env, command):
is_local = host == "127.0.0.1" is_local = host == "127.0.0.1"
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command) script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
script_b64 = base64.b64encode(script.encode()).decode() script_b64 = base64.b64encode(script.encode()).decode()
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"' cmd = (
f'{python} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
)
if not is_local: if not is_local:
cmd = f"ssh {host} '{cmd}'" cmd = f"ssh {host} '{cmd}'"
@@ -309,7 +311,7 @@ def launch_ring(parser, hosts, args, command):
_launch_with_io( _launch_with_io(
RemoteProcess, RemoteProcess,
[ [
((rank, h.ssh_hostname, cwd, files, env, command), {}) ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts) for rank, h in enumerate(hosts)
], ],
args.verbose, args.verbose,
@@ -341,6 +343,7 @@ def launch_nccl(parser, hosts, args, command):
( (
rank, rank,
h.ssh_hostname, h.ssh_hostname,
args.python,
cwd, cwd,
{}, {},
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"], env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
@@ -374,7 +377,7 @@ def launch_jaccl(parser, hosts, args, command):
_launch_with_io( _launch_with_io(
RemoteProcess, RemoteProcess,
[ [
((rank, h.ssh_hostname, cwd, files, env, command), {}) ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
for rank, h in enumerate(hosts) for rank, h in enumerate(hosts)
], ],
args.verbose, args.verbose,
@@ -503,11 +506,20 @@ def main():
default=12345, default=12345,
help="The port to use for the NCCL communication (only for nccl backend)", help="The port to use for the NCCL communication (only for nccl backend)",
) )
parser.add_argument(
"--no-verify-script",
action="store_false",
dest="verify_script",
help="Do not verify that the script exists",
)
parser.add_argument(
"--python", default=sys.executable, help="Use this python on the remote hosts"
)
args, rest = parser.parse_known_args() args, rest = parser.parse_known_args()
if args.print_python: if args.print_python:
print(sys.executable) print(args.python)
return return
if len(rest) == 0: if len(rest) == 0:
@@ -523,10 +535,10 @@ def main():
# Check if the script is a file and convert it to a full path # Check if the script is a file and convert it to a full path
if (script := Path(rest[0])).exists() and script.is_file(): if (script := Path(rest[0])).exists() and script.is_file():
rest[0:1] = [sys.executable, str(script.resolve())] rest[0:1] = [args.python, str(script.resolve())]
elif (command := shutil.which(rest[0])) is not None: elif (command := shutil.which(rest[0])) is not None:
rest[0] = command rest[0] = command
else: elif args.verify_script:
raise ValueError(f"Invalid script or command {rest[0]}") raise ValueError(f"Invalid script or command {rest[0]}")
# Launch # Launch

View File

@@ -266,7 +266,7 @@ if __name__ == "__main__":
entry_points = { entry_points = {
"console_scripts": [ "console_scripts": [
"mlx.launch = mlx._distributed_utils.launch:main", "mlx.launch = mlx._distributed_utils.launch:main",
# "mlx.distributed_config = mlx.distributed_run:distributed_config", "mlx.distributed_config = mlx._distributed_utils.config:main",
] ]
} }
install_requires = [] install_requires = []