From 405d30b6e5e5c87a67d3958604fb9145ae130b31 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 9 Dec 2025 05:58:44 -0800 Subject: [PATCH] Refactor distributed config --- python/mlx/_distributed_utils/common.py | 11 +- python/mlx/_distributed_utils/config.py | 526 ++++++++++++++++++++++++ python/mlx/_distributed_utils/launch.py | 26 +- setup.py | 2 +- 4 files changed, 556 insertions(+), 9 deletions(-) diff --git a/python/mlx/_distributed_utils/common.py b/python/mlx/_distributed_utils/common.py index e2a1b327d..a466668ff 100644 --- a/python/mlx/_distributed_utils/common.py +++ b/python/mlx/_distributed_utils/common.py @@ -1,5 +1,6 @@ # Copyright © 2025 Apple Inc. +import argparse import ipaddress import json import sys @@ -16,6 +17,14 @@ class Host: rdma: list[Optional[str]] +class OptionalBoolAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + if option_string.startswith("--no-"): + setattr(namespace, self.dest, False) + else: + setattr(namespace, self.dest, True) + + def positive_number(x): x = int(x) if x <= 0: @@ -50,7 +59,7 @@ def parse_hostlist(parser, hostlist, repeats): except ValueError: ips = [] for i in range(repeats): - hosts.append(Host(i, h, ips)) + hosts.append(Host(i, h, ips, [])) return hosts diff --git a/python/mlx/_distributed_utils/config.py b/python/mlx/_distributed_utils/config.py index e69de29bb..07750125d 100644 --- a/python/mlx/_distributed_utils/config.py +++ b/python/mlx/_distributed_utils/config.py @@ -0,0 +1,526 @@ +# Copyright © 2025 Apple Inc. + +import argparse +import json +import shlex +import sys +import threading +from collections import defaultdict +from dataclasses import dataclass +from subprocess import DEVNULL, run +from typing import Optional + +import mlx.core as mx + +from .common import ( + Host, + OptionalBoolAction, + log, + log_error, + parse_hostfile, + parse_hostlist, +) + + +@dataclass +class SSHInfo: + can_ssh: bool + has_sudo: bool + + def __bool__(self): + return self.can_ssh + + +@dataclass +class ThunderboltPort: + iface: str + uuid: str + connected_to: Optional[str] + + +@dataclass +class ThunderboltHost: + name: str + ports: list[ThunderboltPort] + + +def add_ethernet_ips(hosts, verbose=False): + # Get the ips for each host + for h in hosts: + log(verbose, "Getting the ip from", h.ssh_hostname) + h.ips.append( + run( + ["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"], + capture_output=True, + text=True, + ).stdout.strip() + ) + + +class IPConfigurator: + def __init__(self, hosts, tb_hosts, uuid_reverse_index): + assigned = set() + ips = defaultdict(list) + ip0 = 0 + ip1 = 0 + for src_node, h in enumerate(tb_hosts): + for src_port, p in enumerate(h.ports): + if not p.connected_to: + continue + if (src_node, src_port) in assigned: + continue + + dst_node, dst_port = uuid_reverse_index[p.connected_to] + + ip_src = f"192.168.{ip0}.{ip1 + 1}" + ip_dst = f"192.168.{ip0}.{ip1 + 2}" + iface_src = p.iface + iface_dst = tb_hosts[dst_node].ports[dst_port].iface + + ips[src_node, dst_node].append((iface_src, ip_src)) + ips[dst_node, src_node].append((iface_dst, ip_dst)) + + assigned.add((src_node, src_port)) + assigned.add((dst_node, dst_port)) + + ip1 += 4 + if ip1 > 255: + ip0 += 1 + ip1 = 0 + if ip0 > 255: + raise ValueError("Ran out of available local IPs") + + self.ips = ips + self.hosts = hosts + self.tb_hosts = tb_hosts + + def setup(self, verbose=False, auto_setup=False): + netmask = "255.255.255.252" + for i, (h, th) in enumerate(zip(self.hosts, self.tb_hosts)): + command = "" + command += "sudo ifconfig bridge0 down\n" + for j in range(len(self.hosts)): + if i == j or (i, j) not in self.ips: + continue + for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]): + command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n" + command += f"sudo route change {peer} -interface {iface}\n" + if auto_setup: + print(f"Running auto setup for {h.ssh_hostname}") + command = command.strip().replace("\n", " ; ") + command = ["ssh", h.ssh_hostname, command] + log(verbose, shlex.join(command)) + run(command) + else: + msg = f"Setup for {h.ssh_hostname}" + print(msg) + print("=" * len(msg)) + print(command) + input("Enter to continue") + print() + + +def parse_hardware_ports(ports_string): + ports = {} + port_name = None + for l in ports_string.decode("utf-8").split("\n"): + if l.startswith("Hardware Port:"): + port_name = l.strip()[15:] + elif l.startswith("Device:"): + ports[port_name] = l.strip()[8:] + port_name = None + return ports + + +def extract_connectivity(hosts, verbose): + # Extract the current connectivity from the remote hosts + thunderbolt_connections = [] + for h in hosts: + log(verbose, "Getting connectivity from", h.ssh_hostname) + thunderbolt_connections.append( + json.loads( + run( + [ + "ssh", + h.ssh_hostname, + "system_profiler", + "SPThunderboltDataType", + "-json", + ], + capture_output=True, + ).stdout + ) + ) + interface_maps = [] + for h in hosts: + log(verbose, "Getting interface names from", h.ssh_hostname) + interface_maps.append( + parse_hardware_ports( + run( + [ + "ssh", + h.ssh_hostname, + "networksetup", + "-listallhardwareports", + ], + capture_output=True, + ).stdout + ) + ) + + # Parse the connectivity into some simple dataclasses + tb_hosts = [] + for c, iface_map in zip(thunderbolt_connections, interface_maps): + name = "" + ports = [] + for t in c["SPThunderboltDataType"]: + uuid = t.get("domain_uuid_key") + if uuid is None: + continue + name = t["device_name_key"] + tag = t["receptacle_1_tag"]["receptacle_id_key"] + items = t.get("_items", []) + connected_items = [item for item in items if "domain_uuid_key" in item] + connected_to = ( + connected_items[0]["domain_uuid_key"] if connected_items else None + ) + iface = iface_map[f"Thunderbolt {tag}"] + ports.append(ThunderboltPort(iface, uuid, connected_to)) + tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface))) + + # Create a reverse index to be able to map uuids to (host, port) quickly + uuid_reverse_index = {} + for i, h in enumerate(tb_hosts): + for j, p in enumerate(h.ports): + uuid_reverse_index[p.uuid] = (i, j) + + return tb_hosts, uuid_reverse_index + + +def make_connectivity_matrix(tb_hosts, uuid_reverse_index): + connectivity = [] + for i, h in enumerate(tb_hosts): + c = [0] * len(tb_hosts) + for p in h.ports: + if p.connected_to is not None: + j, _ = uuid_reverse_index[p.connected_to] + c[j] += 1 + connectivity.append(c) + return connectivity + + +def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index): + # Make ids per node + names = [] + for i in range(len(tb_hosts)): + n = "" + j = i + while True: + n += chr(97 + j % 26) + j //= 26 + if j == 0: + break + names.append(n) + + print("graph G {") + print(" node [shape=rectangle];") + for i, h in enumerate(hosts): + print(f' {names[i]} [label="{h.ssh_hostname}"];') + for i, h in enumerate(tb_hosts): + for p in h.ports: + if not p.connected_to: + continue + dst = uuid_reverse_index[p.connected_to] + if dst[0] < i: + continue + print(f" {names[i]} -- {names[dst[0]]}", end="") + print(f' [label="{p.iface}/{tb_hosts[dst[0]].ports[dst[1]].iface}"]') + print("}") + + +def extract_rings(connectivity): + rings = [] + existing_rings = set() + num_nodes = len(connectivity) + + def dfs(start_node, node, path, visited): + path.append(node) + visited.add(node) + for j in range(num_nodes): + if connectivity[node][j] <= 0: + continue + if j == start_node: + yield path[:] + if j not in visited: + yield from dfs(start_node, j, path, visited) + path.pop() + visited.remove(node) + + for start in range(num_nodes): + for r in dfs(start, start, [], set()): + cnt = min(connectivity[r[i]][r[(i + 1) % len(r)]] for i in range(len(r))) + rkey = tuple(sorted(r)) + if rkey not in existing_rings: + rings.append((r, cnt)) + existing_rings.add(rkey) + + return sorted(rings, key=lambda x: -len(x[0])) + + +def check_valid_mesh(hosts, connectivity, strict=True): + num_nodes = len(connectivity) + for i in range(num_nodes): + for j in range(num_nodes): + if i == j: + continue + if connectivity[i][j] <= 0: + if strict: + log_error( + f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}" + ) + sys.exit(1) + else: + return False + return True + + +def check_ssh_connections(hosts): + results = [None] * len(hosts) + + def _check(hostname, i): + info = SSHInfo(False, False) + results[i] = info + + # Check for ssh + result = run( + [ + "ssh", + "-o", + "BatchMode=yes", + "-o", + "ConnectTimeout=5", + hostname, + "echo", + "success", + ], + stdout=DEVNULL, + stderr=DEVNULL, + ) + info.can_ssh = result.returncode == 0 + if not info.can_ssh: + return + + # Check for sudo + result = run( + [ + "ssh", + "-o", + "BatchMode=yes", + "-o", + "ConnectTimeout=5", + hostname, + "sudo", + "ls", + ], + stdout=DEVNULL, + stderr=DEVNULL, + ) + info.has_sudo = result.returncode == 0 + + threads = [ + threading.Thread(target=_check, args=(h.ssh_hostname, i)) + for i, h in enumerate(hosts) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + if not all(results): + log_error("Could not ssh to the following hosts:") + for i, h in enumerate(hosts): + if not results[i]: + log_error(" - ", h.ssh_hostname) + log_error() + log_error("Maybe they are not set-up for password-less ssh?") + sys.exit(1) + + return results + + +def prepare_ethernet_hostfile(args, hosts): + log(args.verbose, f"Preparing an ethernet hostfile") + add_ethernet_ips(hosts, args.verbose) + + hostfile = [] + for h in hosts: + hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips)) + + if args.output_hostfile: + with open(args.output_hostfile, "w") as f: + json.dump(hostfile, f, indent=4) + else: + print("Hostfile") + print("========") + print(json.dumps(hostfile, indent=4)) + + +def configure_ring(args, hosts, ips, ring): + log(args.verbose, "Prepare a ring hostfile") + ring, count = ring + hostfile = [] + for i, node in enumerate(ring): + h = hosts[node] + peer = ring[i - 1] + hostfile.append( + { + "ssh": h.ssh_hostname, + "ips": [ips.ips[node, peer][c][1] for c in range(count)], + "rdma": [], + } + ) + + ips.setup(verbose=args.verbose, auto_setup=args.auto_setup) + + if args.output_hostfile: + with open(args.output_hostfile, "w") as f: + json.dump(hostfile, f, indent=4) + else: + print("Hostfile") + print("========") + print(json.dumps(hostfile, indent=4)) + + +def configure_jaccl(args, hosts, ips): + log(args.verbose, "Prepare a jaccl hostfile") + add_ethernet_ips(hosts) + + hostfile = [] + for i, h in enumerate(hosts): + rdma = [] + for j in range(len(hosts)): + if i == j: + rdma.append(None) + else: + rdma.append(f"rdma_{ips.ips[i, j][0][0]}") + hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma}) + + ips.setup(verbose=args.verbose, auto_setup=args.auto_setup) + + if args.output_hostfile: + with open(args.output_hostfile, "w") as f: + json.dump(hostfile, f, indent=4) + else: + print("Hostfile") + print("========") + print(json.dumps(hostfile, indent=4)) + + +def prepare_tb_hostfile(args, hosts): + log(args.verbose, f"Preparing for communication over thunderbolt") + tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose) + + if args.dot: + tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index) + return + + ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index) + connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index) + + if args.backend is None: + rings = extract_rings(connectivity) + has_mesh = check_valid_mesh(hosts, connectivity, False) + has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts) + + if not has_ring and not has_mesh: + log_error("Neither thunderbolt mesh nor ring found.") + log_error("Perhaps run with --dot to generate a plot of the connectivity.") + sys.exit(1) + + elif has_ring: + configure_ring(args, hosts, ips, rings[0]) + + else: + configure_jaccl(args, hosts, ips) + + elif args.backend == "ring": + rings = extract_rings(connectivity) + has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts) + if not has_ring: + log_error("Could not find a full ring.") + if len(rings) > 0: + log_error("Rings found:") + for r in rings: + log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}") + sys.exit(1) + configure_ring(args, hosts, ips, rings[0]) + + elif args.backend == "jaccl": + check_valid_mesh(hosts, connectivity) + configure_jaccl(args, hosts, ips) + + +def main(): + parser = argparse.ArgumentParser( + description="Configure remote machines for use with MLX distributed" + ) + parser.add_argument( + "--verbose", action="store_true", help="Print debug messages in stdout" + ) + parser.add_argument( + "--hosts", default="127.0.0.1", help="A comma separated list of hosts" + ) + parser.add_argument("--hostfile", help="The file containing the hosts") + parser.add_argument( + "--over", + choices=["thunderbolt", "ethernet"], + default="thunderbolt", + help="What type of connectivity to configure", + required=True, + ) + parser.add_argument( + "--output-hostfile", help="If provided, save the hostfile to this path" + ) + parser.add_argument( + "--auto-setup", + "--no-auto-setup", + action=OptionalBoolAction, + nargs=0, + dest="auto_setup", + default=None, + ) + parser.add_argument( + "--dot", action="store_true", help="Output the topology in DOT format and exit" + ) + parser.add_argument( + "--backend", + choices=["ring", "jaccl"], + default=None, + help="Which distributed backend to configure", + ) + + # parser.add_argument( + # "--hostfile-only", action="store_true", help="If set only compute the hostfile" + # ) + args = parser.parse_args() + + if args.hostfile is not None: + hosts = parse_hostfile(parser, args.hostfile) + else: + hosts = parse_hostlist(parser, args.hosts, 1) + + # Check that we can ssh + log( + args.verbose, + f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}", + ) + check_ssh_connections(hosts) + + # Prepare a hostfile for communication over ethernet using the ips of the + # provided hostnames + if args.over == "ethernet": + prepare_ethernet_hostfile(args, hosts) + + # Configure the macs for communication over thunderbolt, both via RDMA and IP + else: + prepare_tb_hostfile(args, hosts) diff --git a/python/mlx/_distributed_utils/launch.py b/python/mlx/_distributed_utils/launch.py index 130e316db..ee0fe6264 100644 --- a/python/mlx/_distributed_utils/launch.py +++ b/python/mlx/_distributed_utils/launch.py @@ -45,11 +45,13 @@ class CommandProcess: class RemoteProcess(CommandProcess): - def __init__(self, rank, host, cwd, files, env, command): + def __init__(self, rank, host, python, cwd, files, env, command): is_local = host == "127.0.0.1" script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command) script_b64 = base64.b64encode(script.encode()).decode() - cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"' + cmd = ( + f'{python} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"' + ) if not is_local: cmd = f"ssh {host} '{cmd}'" @@ -309,7 +311,7 @@ def launch_ring(parser, hosts, args, command): _launch_with_io( RemoteProcess, [ - ((rank, h.ssh_hostname, cwd, files, env, command), {}) + ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {}) for rank, h in enumerate(hosts) ], args.verbose, @@ -341,6 +343,7 @@ def launch_nccl(parser, hosts, args, command): ( rank, h.ssh_hostname, + args.python, cwd, {}, env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"], @@ -374,7 +377,7 @@ def launch_jaccl(parser, hosts, args, command): _launch_with_io( RemoteProcess, [ - ((rank, h.ssh_hostname, cwd, files, env, command), {}) + ((rank, h.ssh_hostname, args.python, cwd, files, env, command), {}) for rank, h in enumerate(hosts) ], args.verbose, @@ -503,11 +506,20 @@ def main(): default=12345, help="The port to use for the NCCL communication (only for nccl backend)", ) + parser.add_argument( + "--no-verify-script", + action="store_false", + dest="verify_script", + help="Do not verify that the script exists", + ) + parser.add_argument( + "--python", default=sys.executable, help="Use this python on the remote hosts" + ) args, rest = parser.parse_known_args() if args.print_python: - print(sys.executable) + print(args.python) return if len(rest) == 0: @@ -523,10 +535,10 @@ def main(): # Check if the script is a file and convert it to a full path if (script := Path(rest[0])).exists() and script.is_file(): - rest[0:1] = [sys.executable, str(script.resolve())] + rest[0:1] = [args.python, str(script.resolve())] elif (command := shutil.which(rest[0])) is not None: rest[0] = command - else: + elif args.verify_script: raise ValueError(f"Invalid script or command {rest[0]}") # Launch diff --git a/setup.py b/setup.py index 13d9f9dd6..0e86b4726 100644 --- a/setup.py +++ b/setup.py @@ -266,7 +266,7 @@ if __name__ == "__main__": entry_points = { "console_scripts": [ "mlx.launch = mlx._distributed_utils.launch:main", - # "mlx.distributed_config = mlx.distributed_run:distributed_config", + "mlx.distributed_config = mlx._distributed_utils.config:main", ] } install_requires = []