Add mlx.distributed_config script (#1902)

This commit is contained in:
Angelos Katharopoulos 2025-02-28 11:16:39 -08:00 committed by GitHub
parent 89d327075f
commit 607181644f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 402 additions and 1 deletions

View File

@ -16,6 +16,7 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from select import select from select import select
from subprocess import PIPE, Popen, run from subprocess import PIPE, Popen, run
from typing import Optional
@dataclass @dataclass
@ -25,6 +26,95 @@ class Host:
ips: list[str] ips: list[str]
@dataclass
class ThunderboltPort:
iface: str
uuid: str
connected_to: Optional[str]
@dataclass
class ThunderboltHost:
name: str
ports: list[ThunderboltPort]
def parse_hardware_ports(ports_string):
ports = {}
port_name = None
for l in ports_string.decode("utf-8").split("\n"):
if l.startswith("Hardware Port:"):
port_name = l.strip()[15:]
elif l.startswith("Device:"):
ports[port_name] = l.strip()[8:]
port_name = None
return ports
def extract_rings(hosts, index):
def usable_port(i, j, used_ports):
return (i, j) not in used_ports and hosts[i].ports[j].connected_to is not None
def dfs(start_node, node, path, visited, used_ports):
path.append(node)
visited.add(node)
for j, p in enumerate(hosts[node].ports):
if not usable_port(node, j, used_ports):
continue
next_node, _ = index[p.connected_to]
if next_node == start_node:
yield path[:]
if next_node not in visited:
yield from dfs(start_node, next_node, path, visited, used_ports)
path.pop()
visited.remove(node)
# Concretize maps the found cycle to real thunderbolt ports. It also adds
# those ports to the used set so next cycles can't use them again.
def concretize(cycle, used_ports):
concrete_path = []
for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
for j, p in enumerate(hosts[n1].ports):
if not usable_port(n1, j, used_ports):
continue
n2_hat, nj = index[p.connected_to]
if n2 == n2_hat:
concrete_path.append(((n1, j), (n2, nj)))
used_ports.add((n1, j))
used_ports.add((n2, nj))
break
if concrete_path[-1][0][0] != n1:
raise RuntimeError("Couldn't concretize the cycle")
return concrete_path
# Normalize tries to ensure that the cycles have the same direction so we can
# use them together. We achieve this by selecting the direction such that
# the smallest rank hosts connect to larger rank hosts.
def normalize(path):
small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
if small_to_large > len(path) - small_to_large:
return path
else:
return [(p[1], p[0]) for p in path]
rings = []
used_ports = set()
for start_node in range(len(hosts)):
while True:
ring = []
for r in dfs(start_node, start_node, [], set(), used_ports):
if len(r) > len(ring):
ring = r
# Break early since we won't find a bigger ring no matter what
if len(ring) == len(hosts):
break
if not ring:
break
rings.append(normalize(concretize(ring, used_ports)))
return rings
def positive_number(x): def positive_number(x):
x = int(x) x = int(x)
if x <= 0: if x <= 0:
@ -43,6 +133,11 @@ def log_warning(*args, **kwargs):
print("\033[33m[WARN]", *args, "\033[0m", **kwargs) print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
def log_error(*args, **kwargs):
kwargs["file"] = sys.stderr
print("\033[31m[ERROR]", *args, "\033[0m", **kwargs)
def parse_hostfile(parser, hostfile): def parse_hostfile(parser, hostfile):
"""Parse the json hostfile that contains both the hostnames to ssh into and """Parse the json hostfile that contains both the hostnames to ssh into and
the ips to communicate over when using the ring backend. the ips to communicate over when using the ring backend.
@ -77,6 +172,8 @@ def parse_hostfile(parser, hostfile):
def parse_hostlist(parser, hostlist, repeats): def parse_hostlist(parser, hostlist, repeats):
hosts = [] hosts = []
for i, h in enumerate(hostlist.split(",")): for i, h in enumerate(hostlist.split(",")):
if h == "":
raise ValueError("Hostname cannot be empty")
try: try:
ipaddress.ip_address(h) ipaddress.ip_address(h)
ips = [h] ips = [h]
@ -271,6 +368,305 @@ def launch_mpi(parser, hosts, args, command):
pass pass
def check_ssh_connections(hosts):
results = [False] * len(hosts)
def _check(hostname, i):
result = run(
[
"ssh",
"-o",
"BatchMode=yes",
"-o",
"ConnectTimeout=5",
hostname,
"echo",
"success",
],
stdout=PIPE,
stderr=PIPE,
)
results[i] = result.returncode == 0
threads = [
threading.Thread(target=_check, args=(h.ssh_hostname, i))
for i, h in enumerate(hosts)
]
for t in threads:
t.start()
for t in threads:
t.join()
if not all(results):
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
log_error(" - ", h.ssh_hostname)
log_error()
log_error("Maybe they are not set-up for password-less ssh?")
sys.exit(1)
def prepare_tb_ring(args, hosts):
log(
args.verbose,
f"Preparing a thunderbolt ring for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
if args.auto_setup and args.verbose:
log_warning(
"--auto-setup is requested which requires password-less sudo",
"on the remote hosts",
)
# Extract the current connectivity from the remote hosts
thunderbolt_connections = []
for h in hosts:
log(args.verbose, "Getting connectivity from", h.ssh_hostname)
thunderbolt_connections.append(
json.loads(
run(
[
"ssh",
h.ssh_hostname,
"system_profiler",
"SPThunderboltDataType",
"-json",
],
capture_output=True,
).stdout
)
)
interface_maps = []
for h in hosts:
log(args.verbose, "Getting interface names from", h.ssh_hostname)
interface_maps.append(
parse_hardware_ports(
run(
[
"ssh",
h.ssh_hostname,
"networksetup",
"-listallhardwareports",
],
capture_output=True,
).stdout
)
)
# Parse the connectivity into some simple dataclasses
tb_hosts = []
for c, iface_map in zip(thunderbolt_connections, interface_maps):
name = ""
ports = []
for t in c["SPThunderboltDataType"]:
name = t["device_name_key"]
uuid = t["domain_uuid_key"]
tag = t["receptacle_1_tag"]["receptacle_id_key"]
if items := t.get("_items", []):
connected_to = items[0]["domain_uuid_key"]
else:
connected_to = None
iface = iface_map[f"Thunderbolt {tag}"]
ports.append(ThunderboltPort(iface, uuid, connected_to))
tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))
# Create a reverse index to be able to map uuids to (host, port) quickly
uuid_reverse_index = {}
for i, h in enumerate(tb_hosts):
for j, p in enumerate(h.ports):
uuid_reverse_index[p.uuid] = (i, j)
# Find the rings by simply walking and marking visited (host, port) tuples
# and keeping the largest rings greedily.
log(args.verbose, "Extracting rings from the parsed connectivity")
rings = extract_rings(tb_hosts, uuid_reverse_index)
# Just output a DOT graphical representation of the found rings
if args.dot:
names = []
for i in range(len(tb_hosts)):
n = ""
j = i
while True:
n += chr(97 + j % 26)
j //= 26
if j == 0:
break
names.append(n)
print("graph G {")
print(" node [shape=rectangle];")
for i, h in enumerate(hosts):
print(f' {names[i]} [label="{h.ssh_hostname}"];')
for r in rings:
for (i, _), (j, _) in r:
print(f" {names[i]} -- {names[j]};")
print("}")
return
# Assign IPs to each interface such that the interfaces can communicate
ips = {}
pairs = {}
expecting = set()
ip0 = 0
ip1 = 0
netmask = "255.255.255.252"
for r in rings:
for a, b in r:
ips[a] = f"192.168.{ip0}.{ip1 + 1}"
ips[b] = f"192.168.{ip0}.{ip1 + 2}"
pairs[a] = b
pairs[b] = a
expecting.add(b)
ip1 += 4
if ip1 > 255:
ip0 += 1
ip1 = 0
if ip0 > 255:
raise ValueError("Ran out of available local IPs for the ring")
# Create the hostfile
hostfile = []
for i, h in enumerate(hosts):
host = {
"ssh": h.ssh_hostname,
"ips": [
ips[i, j]
for j, p in enumerate(tb_hosts[i].ports)
if (i, j) in expecting
],
}
hostfile.append(host)
if not args.hostfile_only:
for i, h in enumerate(hosts):
command = ""
command += "sudo ifconfig bridge0 down\n"
for j, p in enumerate(tb_hosts[i].ports):
if (i, j) not in ips:
continue
iface = p.iface
ip = ips[i, j]
peer = ips[pairs[i, j]]
command += f"sudo ifconfig {iface} inet {ip} netmask {netmask}\n"
command += f"sudo route change {peer} -interface {iface}\n"
if args.auto_setup:
print(f"Running auto setup for {h.ssh_hostname}")
command = command.strip().replace("\n", " && ")
command = ["ssh", h.ssh_hostname, command]
log(args.verbose, shlex.join(command))
run(command)
else:
msg = f"Setup for {h.ssh_hostname}"
print(msg)
print("=" * len(msg))
print(command)
input("Enter to continue")
print()
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def prepare_hostfile(args, hosts):
log(
args.verbose,
f"Preparing an ethernet hostfile for {', '.join(h.ssh_hostname for h in hosts)}",
)
# Check that we can ssh
check_ssh_connections(hosts)
# Get the ips for each host
for h in hosts:
log(args.verbose, "Getting the ip from", h.ssh_hostname)
h.ips.append(
run(
["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
capture_output=True,
text=True,
).stdout.strip()
)
hostfile = []
for h in hosts:
hostfile.append(dict(ssh=h.ssh_hostname, ips=h.ips))
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
json.dump(hostfile, f, indent=4)
else:
print("Hostfile")
print("========")
print(json.dumps(hostfile, indent=4))
def distributed_config():
parser = argparse.ArgumentParser(
description="Configure remote machines for use with MLX distributed"
)
parser.add_argument(
"--verbose", action="store_true", help="Print debug messages in stdout"
)
parser.add_argument(
"--backend",
choices=["ring", "mpi"],
default="ring",
help="Which distributed backend to configure",
)
parser.add_argument(
"--over",
choices=["thunderbolt", "ethernet"],
default="thunderbolt",
help="What type of connectivity to configure",
)
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--dot", action="store_true", help="Output the topology in DOT format and exit"
)
parser.add_argument(
"--hostfile-only", action="store_true", help="If set only compute the hostfile"
)
parser.add_argument(
"--output-hostfile", help="If provided, save the hostfile to this path"
)
parser.add_argument(
"--auto-setup",
action="store_true",
help="If set we will attempt to automatically configure the machines via ssh",
)
args = parser.parse_args()
if args.backend == "mpi" and args.over == "thunderbolt":
raise ValueError(
(
"The configuration of MPI over thunderbolt is "
"not supported yet by mlx.distributed_config"
)
)
if args.hostfile is not None:
hosts = parse_hostfile(parser, args.hostfile)
else:
hosts = parse_hostlist(parser, args.hosts, 1)
if args.over == "thunderbolt":
prepare_tb_ring(args, hosts)
else:
prepare_hostfile(args, hosts)
def main(): def main():
parser = argparse.ArgumentParser(description="Launch an MLX distributed program") parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
parser.add_argument( parser.add_argument(

View File

@ -194,7 +194,12 @@ if __name__ == "__main__":
"typing_extensions", "typing_extensions",
], ],
}, },
entry_points={"console_scripts": ["mlx.launch = mlx.distributed_run:main"]}, entry_points={
"console_scripts": [
"mlx.launch = mlx.distributed_run:main",
"mlx.distributed_config = mlx.distributed_run:distributed_config",
]
},
ext_modules=[CMakeExtension("mlx.core")], ext_modules=[CMakeExtension("mlx.core")],
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs}, cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
zip_safe=False, zip_safe=False,