diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py
index 0857e286e..7f45faf69 100644
--- a/python/mlx/distributed_run.py
+++ b/python/mlx/distributed_run.py
@@ -16,6 +16,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from select import select
 from subprocess import PIPE, Popen, run
+from typing import Optional
 
 
 @dataclass
@@ -25,6 +26,95 @@ class Host:
     ips: list[str]
 
 
+@dataclass
+class ThunderboltPort:
+    iface: str
+    uuid: str
+    connected_to: Optional[str]
+
+
+@dataclass
+class ThunderboltHost:
+    name: str
+    ports: list[ThunderboltPort]
+
+
+def parse_hardware_ports(ports_string):
+    ports = {}
+    port_name = None
+    for l in ports_string.decode("utf-8").split("\n"):
+        if l.startswith("Hardware Port:"):
+            port_name = l.strip()[15:]
+        elif l.startswith("Device:"):
+            ports[port_name] = l.strip()[8:]
+            port_name = None
+    return ports
+
+
+def extract_rings(hosts, index):
+    def usable_port(i, j, used_ports):
+        return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
+
+    def dfs(start_node, node, path, visited, used_ports):
+        path.append(node)
+        visited.add(node)
+        for j, p in enumerate(hosts[node].ports):
+            if not usable_port(node, j, used_ports):
+                continue
+            next_node, _ = index[p.connected_to]
+            if next_node == start_node:
+                yield path[:]
+            if next_node not in visited:
+                yield from dfs(start_node, next_node, path, visited, used_ports)
+        path.pop()
+        visited.remove(node)
+
+    # Concretize maps the found cycle to real thunderbolt ports. It also adds
+    # those ports to the used set so next cycles can't use them again.
+    def concretize(cycle, used_ports):
+        concrete_path = []
+        for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
+            for j, p in enumerate(hosts[n1].ports):
+                if not usable_port(n1, j, used_ports):
+                    continue
+                n2_hat, nj = index[p.connected_to]
+                if n2 == n2_hat:
+                    concrete_path.append(((n1, j), (n2, nj)))
+                    used_ports.add((n1, j))
+                    used_ports.add((n2, nj))
+                    break
+            if concrete_path[-1][0][0] != n1:
+                raise RuntimeError("Couldn't concretize the cycle")
+        return concrete_path
+
+    # Normalize tries to ensure that the cycles have the same direction so we
+    # can use them together. We achieve this by selecting the direction such
+    # that the smallest rank hosts connect to larger rank hosts.
+    def normalize(path):
+        small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
+        if small_to_large > len(path) - small_to_large:
+            return path
+        else:
+            return [(p[1], p[0]) for p in path]
+
+    rings = []
+    used_ports = set()
+    for start_node in range(len(hosts)):
+        while True:
+            ring = []
+            for r in dfs(start_node, start_node, [], set(), used_ports):
+                if len(r) > len(ring):
+                    ring = r
+                # Break early since we won't find a bigger ring no matter what
+                if len(ring) == len(hosts):
+                    break
+            if not ring:
+                break
+            rings.append(normalize(concretize(ring, used_ports)))
+
+    return rings
+
+
 def positive_number(x):
     x = int(x)
     if x <= 0:
@@ -43,6 +133,11 @@ def log_warning(*args, **kwargs):
     print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
 
 
+def log_error(*args, **kwargs):
+    kwargs["file"] = sys.stderr
+    print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
+
+
 def parse_hostfile(parser, hostfile):
     """Parse the json hostfile that contains both the hostnames to ssh into
     and the ips to communicate over when using the ring backend.
@@ -77,6 +172,8 @@ def parse_hostfile(parser, hostfile):
 def parse_hostlist(parser, hostlist, repeats):
     hosts = []
     for i, h in enumerate(hostlist.split(",")):
+        if h == "":
+            raise ValueError("Hostname cannot be empty")
         try:
             ipaddress.ip_address(h)
             ips = [h]
@@ -271,6 +368,305 @@ def launch_mpi(parser, hosts, args, command):
 
         pass
 
+
+def check_ssh_connections(hosts):
+    results = [False] * len(hosts)
+
+    def _check(hostname, i):
+        result = run(
+            [
+                "ssh",
+                "-o",
+                "BatchMode=yes",
+                "-o",
+                "ConnectTimeout=5",
+                hostname,
+                "echo",
+                "success",
+            ],
+            stdout=PIPE,
+            stderr=PIPE,
+        )
+        results[i] = result.returncode == 0
+
+    threads = [
+        threading.Thread(target=_check, args=(h.ssh_hostname, i))
+        for i, h in enumerate(hosts)
+    ]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+    if not all(results):
+        log_error("Could not ssh to the following hosts:")
+        for i, h in enumerate(hosts):
+            if not results[i]:
+                log_error(" - ", h.ssh_hostname)
+        log_error()
+        log_error("Maybe they are not set up for password-less ssh?")
+        sys.exit(1)
+
+
+def prepare_tb_ring(args, hosts):
+    log(
+        args.verbose,
+        f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
+    )
+
+    # Check that we can ssh
+    check_ssh_connections(hosts)
+    if args.auto_setup and args.verbose:
+        log_warning(
+            "--auto-setup is requested which requires password-less sudo",
+            "on the remote hosts",
+        )
+
+    # Extract the current connectivity from the remote hosts
+    thunderbolt_connections = []
+    for h in hosts:
+        log(args.verbose, "Getting connectivity from", h.ssh_hostname)
+        thunderbolt_connections.append(
+            json.loads(
+                run(
+                    [
+                        "ssh",
+                        h.ssh_hostname,
+                        "system_profiler",
+                        "SPThunderboltDataType",
+                        "-json",
+                    ],
+                    capture_output=True,
+                ).stdout
+            )
+        )
+    interface_maps = []
+    for h in hosts:
+        log(args.verbose, "Getting interface names from", h.ssh_hostname)
+        interface_maps.append(
+            parse_hardware_ports(
+                run(
+                    [
+                        "ssh",
+                        h.ssh_hostname,
+                        "networksetup",
+                        "-listallhardwareports",
+                    ],
+                    capture_output=True,
+                ).stdout
+            )
+        )
+
+    # Parse the connectivity into some simple dataclasses
+    tb_hosts = []
+    for c, iface_map in zip(thunderbolt_connections, interface_maps):
+        name = ""
+        ports = []
+        for t in c["SPThunderboltDataType"]:
+            name = t["device_name_key"]
+            uuid = t["domain_uuid_key"]
+            tag = t["receptacle_1_tag"]["receptacle_id_key"]
+            if items := t.get("_items", []):
+                connected_to = items[0]["domain_uuid_key"]
+            else:
+                connected_to = None
+            iface = iface_map[f"Thunderbolt {tag}"]
+            ports.append(ThunderboltPort(iface, uuid, connected_to))
+        tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
+
+    # Create a reverse index to be able to map uuids to (host, port) quickly
+    uuid_reverse_index = {}
+    for i, h in enumerate(tb_hosts):
+        for j, p in enumerate(h.ports):
+            uuid_reverse_index[p.uuid] = (i, j)
+
+    # Find the rings by simply walking and marking visited (host, port) tuples
+    # and keeping the largest rings greedily.
+    log(args.verbose, "Extracting rings from the parsed connectivity")
+    rings = extract_rings(tb_hosts, uuid_reverse_index)
+
+    # Just output a DOT graphical representation of the found rings
+    if args.dot:
+        names = []
+        for i in range(len(tb_hosts)):
+            n = ""
+            j = i
+            while True:
+                n += chr(97 + j % 26)
+                j //= 26
+                if j == 0:
+                    break
+            names.append(n)
+
+        print("graph G {")
+        print(" node [shape=rectangle];")
+        for i, h in enumerate(hosts):
+            print(f' {names[i]} [label="{h.ssh_hostname}"];')
+        for r in rings:
+            for (i, _), (j, _) in r:
+                print(f" {names[i]} -- {names[j]};")
+        print("}")
+        return
+
+    # Assign IPs to each interface such that the interfaces can communicate
+    ips = {}
+    pairs = {}
+    expecting = set()
+    ip0 = 0
+    ip1 = 0
+    netmask = "255.255.255.252"
+    for r in rings:
+        for a, b in r:
+            ips[a] = f"192.168.{ip0}.{ip1 + 1}"
+            ips[b] = f"192.168.{ip0}.{ip1 + 2}"
+            pairs[a] = b
+            pairs[b] = a
+            expecting.add(b)
+            ip1 += 4
+            if ip1 > 255:
+                ip0 += 1
+                ip1 = 0
+            if ip0 > 255:
+                raise ValueError("Ran out of available local IPs for the ring")
+
+    # Create the hostfile
+    hostfile = []
+    for i, h in enumerate(hosts):
+        host = {
+            "ssh": h.ssh_hostname,
+            "ips": [
+                ips[i, j]
+                for j, p in enumerate(tb_hosts[i].ports)
+                if (i, j) in expecting
+            ],
+        }
+        hostfile.append(host)
+
+    if not args.hostfile_only:
+        for i, h in enumerate(hosts):
+            command = ""
+            command += "sudo ifconfig bridge0 down\n"
+            for j, p in enumerate(tb_hosts[i].ports):
+                if (i, j) not in ips:
+                    continue
+                iface = p.iface
+                ip = ips[i, j]
+                peer = ips[pairs[i, j]]
+                command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
+                command += f"sudo route change {peer} -interface {iface}\n"
+            if args.auto_setup:
+                print(f"Running auto setup for {h.ssh_hostname}")
+                command = command.strip().replace("\n", " && ")
+                command = ["ssh", h.ssh_hostname, command]
+                log(args.verbose, shlex.join(command))
+                run(command)
+            else:
+                msg = f"Setup for {h.ssh_hostname}"
+                print(msg)
+                print("=" * len(msg))
+                print(command)
+                input("Enter to continue")
+                print()
+
+    if args.output_hostfile:
+        with open(args.output_hostfile, "w") as f:
+            json.dump(hostfile, f, indent=4)
+    else:
+        print("Hostfile")
+        print("========")
+        print(json.dumps(hostfile, indent=4))
+
+
+def prepare_hostfile(args, hosts):
+    log(
+        args.verbose,
+        f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
+    )
+
+    # Check that we can ssh
+    check_ssh_connections(hosts)
+
+    # Get the ips for each host
+    for h in hosts:
+        log(args.verbose, "Getting the ip from", h.ssh_hostname)
+        h.ips.append(
+            run(
+                ["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
+                capture_output=True,
+                text=True,
+            ).stdout.strip()
+        )
+
+    hostfile = []
+    for h in hosts:
+        hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
+
+    if args.output_hostfile:
+        with open(args.output_hostfile, "w") as f:
+            json.dump(hostfile, f, indent=4)
+    else:
+        print("Hostfile")
+        print("========")
+        print(json.dumps(hostfile, indent=4))
+
+
+def distributed_config():
+    parser = argparse.ArgumentParser(
+        description="Configure remote machines for use with MLX distributed"
+    )
+    parser.add_argument(
+        "--verbose", action="store_true", help="Print debug messages in stdout"
+    )
+    parser.add_argument(
+        "--backend",
+        choices=["ring", "mpi"],
+        default="ring",
+        help="Which distributed backend to configure",
+    )
+    parser.add_argument(
+        "--over",
+        choices=["thunderbolt", "ethernet"],
+        default="thunderbolt",
+        help="What type of connectivity to configure",
+    )
+    parser.add_argument(
+ "--hosts", default="127.0.0.1", help="A comma separated list of hosts" + ) + parser.add_argument("--hostfile", help="The file containing the hosts") + parser.add_argument( + "--dot", action="store_true", help="Output the topology in DOT format and exit" + ) + parser.add_argument( + "--hostfile-only", action="store_true", help="If set only compute the hostfile" + ) + parser.add_argument( + "--output-hostfile", help="If provided, save the hostfile to this path" + ) + parser.add_argument( + "--auto-setup", + action="store_true", + help="If set we will attempt to automatically configure the machines via ssh", + ) + args = parser.parse_args() + + if args.backend == "mpi" and args.over == "thunderbolt": + raise ValueError( + ( + "The configuration of MPI over thunderbolt is " + "not supported yet by mlx.distributed_config" + ) + ) + + if args.hostfile is not None: + hosts = parse_hostfile(parser, args.hostfile) + else: + hosts = parse_hostlist(parser, args.hosts, 1) + + if args.over == "thunderbolt": + prepare_tb_ring(args, hosts) + else: + prepare_hostfile(args, hosts) + + def main(): parser = argparse.ArgumentParser(description="Launch an MLX distributed program") parser.add_argument( diff --git a/setup.py b/setup.py index 9296254ff..d4b5e15dd 100644 --- a/setup.py +++ b/setup.py @@ -194,7 +194,12 @@ if __name__ == "__main__": "typing_extensions", ], }, - entry_points={"console_scripts": ["mlx.launch = mlx.distributed_run:main"]}, + entry_points={ + "console_scripts": [ + "mlx.launch = mlx.distributed_run:main", + "mlx.distributed_config = mlx.distributed_run:distributed_config", + ] + }, ext_modules=[CMakeExtension("mlx.core")], cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, zip_safe=False,