mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
Add mlx.distributed_config script (#1902)
This commit is contained in:
parent
89d327075f
commit
607181644f
@ -16,6 +16,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from select import select
|
||||
from subprocess import PIPE, Popen, run
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -25,6 +26,95 @@ class Host:
|
||||
ips: list[str]
|
||||
|
||||
|
||||
@dataclass
class ThunderboltPort:
    """One thunderbolt port of a host, as used by the ring extraction."""

    # Name of the network interface associated with this port.
    iface: str
    # Identifier of this port; other ports refer to it via ``connected_to``.
    uuid: str
    # ``uuid`` of the port on the other end, or None when nothing is attached.
    connected_to: Optional[str]
|
||||
|
||||
|
||||
@dataclass
class ThunderboltHost:
    """A host together with the thunderbolt ports found on it."""

    # Device name of the host.
    name: str
    # The host's ports; positions in this list are used as port indices.
    ports: list[ThunderboltPort]
|
||||
|
||||
|
||||
def parse_hardware_ports(ports_string):
    """Map hardware port names to device names from a textual port listing.

    The listing is scanned line by line. Every ``Hardware Port: <name>`` line
    is paired with the next ``Device: <device>`` line that follows it; all
    other lines are ignored.

    Args:
        ports_string: The listing, either as UTF-8 encoded ``bytes`` (what a
            captured subprocess ``stdout`` provides) or as ``str``.

    Returns:
        A dict mapping each hardware port name to its device name.
    """
    if isinstance(ports_string, bytes):
        ports_string = ports_string.decode("utf-8")

    ports = {}
    port_name = None
    for line in ports_string.split("\n"):
        key, sep, value = line.partition(":")
        if not sep:
            continue
        if key == "Hardware Port":
            port_name = value.strip()
        elif key == "Device" and port_name is not None:
            # A "Device:" line without a preceding "Hardware Port:" line is
            # skipped instead of being stored under a ``None`` key.
            ports[port_name] = value.strip()
            port_name = None
    return ports
|
||||
|
||||
|
||||
def extract_rings(hosts, index):
    """Greedily extract rings (cycles of hosts) from the port connectivity.

    For each start host, in order, the largest cycle through still unused
    ports is searched for repeatedly; every cycle found is mapped to concrete
    ports, which are then marked as used so later rings cannot reuse them.

    Args:
        hosts: Sequence of objects with a ``ports`` list; each port has a
            ``connected_to`` attribute holding the uuid of the peer port or
            ``None``.
        index: Dict mapping a port uuid to its ``(host_index, port_index)``.

    Returns:
        A list of rings. Each ring is a list of
        ``((host, port), (host, port))`` pairs, one per link of the cycle.

    Raises:
        RuntimeError: If a cycle found by the search cannot be mapped to
            unused ports.
    """

    def usable_port(i, j, used_ports):
        # A port is usable when it is not part of an earlier ring and it is
        # connected to a port we know about. Ports whose peer uuid is missing
        # from the index are skipped rather than raising a KeyError.
        peer = hosts[i].ports[j].connected_to
        return (i, j) not in used_ports and peer is not None and peer in index

    def dfs(start_node, node, path, visited, used_ports):
        # Depth-first walk over usable ports that yields a copy of the current
        # path every time a port leads back to the start node.
        # NOTE(review): a port leading straight back over the link we arrived
        # on also counts, so two hosts joined by a single cable are yielded as
        # a cycle that concretize() then rejects with a RuntimeError.
        path.append(node)
        visited.add(node)
        for j, p in enumerate(hosts[node].ports):
            if not usable_port(node, j, used_ports):
                continue
            next_node, _ = index[p.connected_to]
            if next_node == start_node:
                yield path[:]
            if next_node not in visited:
                yield from dfs(start_node, next_node, path, visited, used_ports)
        path.pop()
        visited.remove(node)

    # Concretize maps the found cycle to real thunderbolt ports. It also adds
    # those ports to the used set so next cycles can't use them again.
    def concretize(cycle, used_ports):
        concrete_path = []
        for n1, n2 in zip(cycle, cycle[1:] + cycle[:1]):
            for j, p in enumerate(hosts[n1].ports):
                if not usable_port(n1, j, used_ports):
                    continue
                n2_hat, nj = index[p.connected_to]
                if n2 == n2_hat:
                    concrete_path.append(((n1, j), (n2, nj)))
                    used_ports.add((n1, j))
                    used_ports.add((n2, nj))
                    break
            else:
                # No unused port of n1 leads to n2.
                raise RuntimeError("Couldn't concretize the cycle")
        return concrete_path

    # Normalize tries to ensure that the cycles have the same direction so we
    # can use them together. We achieve this by selecting the direction such
    # that the smallest rank hosts connect to larger rank hosts.
    def normalize(path):
        small_to_large = sum(1 for p in path if p[0][0] < p[1][0])
        if small_to_large > len(path) - small_to_large:
            return path
        return [(p[1], p[0]) for p in path]

    rings = []
    used_ports = set()
    for start_node in range(len(hosts)):
        while True:
            ring = []
            for r in dfs(start_node, start_node, [], set(), used_ports):
                if len(r) > len(ring):
                    ring = r
                # Break early since we won't find a bigger ring no matter what
                if len(ring) == len(hosts):
                    break
            if not ring:
                break
            rings.append(normalize(concretize(ring, used_ports)))

    return rings
|
||||
|
||||
|
||||
def positive_number(x):
|
||||
x = int(x)
|
||||
if x <= 0:
|
||||
@ -43,6 +133,11 @@ def log_warning(*args, **kwargs):
|
||||
print("\033[33m[WARN]", *args, "\033[0m", **kwargs)
|
||||
|
||||
|
||||
def log_error(*args, **kwargs):
    """Print an ``[ERROR]`` message to stderr.

    The positional arguments are printed between the escape sequences
    ``\\033[31m`` and ``\\033[0m``. Keyword arguments are forwarded to
    ``print``, except ``file`` which is always ``sys.stderr``.
    """
    print_options = {**kwargs, "file": sys.stderr}
    print("\033[31m[ERROR]", *args, "\033[0m", **print_options)
|
||||
|
||||
|
||||
def parse_hostfile(parser, hostfile):
|
||||
"""Parse the json hostfile that contains both the hostnames to ssh into and
|
||||
the ips to communicate over when using the ring backend.
|
||||
@ -77,6 +172,8 @@ def parse_hostfile(parser, hostfile):
|
||||
def parse_hostlist(parser, hostlist, repeats):
|
||||
hosts = []
|
||||
for i, h in enumerate(hostlist.split(",")):
|
||||
if h == "":
|
||||
raise ValueError("Hostname cannot be empty")
|
||||
try:
|
||||
ipaddress.ip_address(h)
|
||||
ips = [h]
|
||||
@ -271,6 +368,305 @@ def launch_mpi(parser, hosts, args, command):
|
||||
pass
|
||||
|
||||
|
||||
def check_ssh_connections(hosts):
    """Check that every host can be reached over non-interactive ssh.

    One thread per host runs ``ssh -o BatchMode=yes -o ConnectTimeout=5
    <host> echo success``; a host counts as reachable when that command exits
    with status 0. If any host is not reachable, the failing hostnames are
    logged as errors and the process exits with status 1.

    Args:
        hosts: Sequence of objects exposing an ``ssh_hostname`` attribute.
    """
    reachable = [False] * len(hosts)

    def _probe(hostname, slot):
        probe_command = [
            "ssh",
            "-o",
            "BatchMode=yes",
            "-o",
            "ConnectTimeout=5",
            hostname,
            "echo",
            "success",
        ]
        completed = run(probe_command, stdout=PIPE, stderr=PIPE)
        reachable[slot] = completed.returncode == 0

    workers = []
    for slot, host in enumerate(hosts):
        workers.append(threading.Thread(target=_probe, args=(host.ssh_hostname, slot)))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    if all(reachable):
        return

    log_error("Could not ssh to the following hosts:")
    for ok, host in zip(reachable, hosts):
        if not ok:
            log_error(" - ", host.ssh_hostname)
    log_error()
    log_error("Maybe they are not set-up for password-less ssh?")
    sys.exit(1)
|
||||
|
||||
|
||||
def prepare_tb_ring(args, hosts):
    """Discover thunderbolt rings between the hosts and prepare their setup.

    Steps, in order:
      1. Verify ssh access to every host.
      2. Over ssh, read each host's ``system_profiler SPThunderboltDataType
         -json`` output and its ``networksetup -listallhardwareports`` output.
      3. Build the port connectivity and extract rings from it.
      4. With ``args.dot``: print the rings as a DOT graph and return.
      5. Otherwise assign one ``192.168.x.y`` address pair (netmask
         255.255.255.252) per ring link, build the hostfile, print or run
         the ``ifconfig``/``route`` commands (unless ``args.hostfile_only``),
         and finally write or print the hostfile.

    Args:
        args: Parsed command line options; reads ``verbose``, ``auto_setup``,
            ``dot``, ``hostfile_only`` and ``output_hostfile``.
        hosts: Sequence of hosts exposing ``ssh_hostname``.

    Raises:
        ValueError: When the ring links exhaust the 192.168.0.0-192.168.255.255
            range used for address assignment.
    """

    def remote_stdout(hostname, *remote_command):
        # Run a command on the host via ssh and return its raw stdout bytes.
        return run(["ssh", hostname, *remote_command], capture_output=True).stdout

    host_names = ", ".join(h.ssh_hostname for h in hosts)
    log(args.verbose, f"Preparing a thunderbolt ring for {host_names}")

    # Check that we can ssh
    check_ssh_connections(hosts)
    if args.auto_setup and args.verbose:
        log_warning(
            "--auto-setup is requested which requires password-less sudo",
            "on the remote hosts",
        )

    # Extract the current connectivity from the remote hosts
    thunderbolt_connections = []
    for h in hosts:
        log(args.verbose, "Getting connectivity from", h.ssh_hostname)
        profile_json = remote_stdout(
            h.ssh_hostname, "system_profiler", "SPThunderboltDataType", "-json"
        )
        thunderbolt_connections.append(json.loads(profile_json))

    interface_maps = []
    for h in hosts:
        log(args.verbose, "Getting interface names from", h.ssh_hostname)
        listing = remote_stdout(h.ssh_hostname, "networksetup", "-listallhardwareports")
        interface_maps.append(parse_hardware_ports(listing))

    # Parse the connectivity into some simple dataclasses
    tb_hosts = []
    for profile, iface_map in zip(thunderbolt_connections, interface_maps):
        name = ""
        ports = []
        for bus in profile["SPThunderboltDataType"]:
            name = bus["device_name_key"]
            uuid = bus["domain_uuid_key"]
            tag = bus["receptacle_1_tag"]["receptacle_id_key"]
            attached = bus.get("_items", [])
            connected_to = attached[0]["domain_uuid_key"] if attached else None
            iface = iface_map[f"Thunderbolt {tag}"]
            ports.append(ThunderboltPort(iface, uuid, connected_to))
        # Port indices used below refer to this interface-name ordering.
        ports.sort(key=lambda port: port.iface)
        tb_hosts.append(ThunderboltHost(name, ports))

    # Create a reverse index to be able to map uuids to (host, port) quickly
    uuid_reverse_index = {
        port.uuid: (i, j)
        for i, tb_host in enumerate(tb_hosts)
        for j, port in enumerate(tb_host.ports)
    }

    # Find the rings by simply walking and marking visited (host, port) tuples
    # and keeping the largest rings greedily.
    log(args.verbose, "Extracting rings from the parsed connectivity")
    rings = extract_rings(tb_hosts, uuid_reverse_index)

    # Just output a DOT graphical representation of the found rings
    if args.dot:

        def node_id(k):
            # Letters a-z in base 26, least significant letter first.
            letters = [chr(97 + k % 26)]
            k //= 26
            while k != 0:
                letters.append(chr(97 + k % 26))
                k //= 26
            return "".join(letters)

        names = [node_id(i) for i in range(len(tb_hosts))]

        print("graph G {")
        print(" node [shape=rectangle];")
        for i, h in enumerate(hosts):
            print(f' {names[i]} [label="{h.ssh_hostname}"];')
        for ring in rings:
            for (i, _), (j, _) in ring:
                print(f" {names[i]} -- {names[j]};")
        print("}")
        return

    # Assign IPs to each interface such that the interfaces can communicate
    ips = {}
    pairs = {}
    expecting = set()
    netmask = "255.255.255.252"
    links_done = 0
    for ring in rings:
        for src, dst in ring:
            # Every link gets its own block of four consecutive addresses.
            ip0, ip1 = divmod(4 * links_done, 256)
            ips[src] = f"192.168.{ip0}.{ip1 + 1}"
            ips[dst] = f"192.168.{ip0}.{ip1 + 2}"
            pairs[src] = dst
            pairs[dst] = src
            expecting.add(dst)
            links_done += 1
            if (4 * links_done) // 256 > 255:
                raise ValueError("Ran out of available local IPs for the ring")

    # Create the hostfile
    hostfile = []
    for i, h in enumerate(hosts):
        listening_ips = [
            ips[i, j] for j in range(len(tb_hosts[i].ports)) if (i, j) in expecting
        ]
        hostfile.append({"ssh": h.ssh_hostname, "ips": listening_ips})

    if not args.hostfile_only:
        for i, h in enumerate(hosts):
            steps = ["sudo ifconfig bridge0 down"]
            for j, p in enumerate(tb_hosts[i].ports):
                if (i, j) not in ips:
                    continue
                iface = p.iface
                ip = ips[i, j]
                peer = ips[pairs[i, j]]
                steps.append(f"sudo ifconfig {iface} inet {ip} netmask {netmask}")
                steps.append(f"sudo route change {peer} -interface {iface}")
            command = "\n".join(steps) + "\n"

            if args.auto_setup:
                print(f"Running auto setup for {h.ssh_hostname}")
                one_liner = command.strip().replace("\n", " && ")
                ssh_command = ["ssh", h.ssh_hostname, one_liner]
                log(args.verbose, shlex.join(ssh_command))
                run(ssh_command)
            else:
                msg = f"Setup for {h.ssh_hostname}"
                print(msg)
                print("=" * len(msg))
                print(command)
                input("Enter to continue")
            print()

    if args.output_hostfile:
        with open(args.output_hostfile, "w") as f:
            json.dump(hostfile, f, indent=4)
    else:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def prepare_hostfile(args, hosts):
    """Build an ethernet hostfile by asking every host for its ``en0`` address.

    After verifying ssh access, ``ipconfig getifaddr en0`` is run on each host
    over ssh and its stripped output is appended to that host's ``ips`` list.
    The resulting hostfile is written to ``args.output_hostfile`` when set,
    and printed to stdout otherwise.

    Args:
        args: Parsed command line options; reads ``verbose`` and
            ``output_hostfile``.
        hosts: Sequence of hosts exposing ``ssh_hostname`` and a mutable
            ``ips`` list (modified in place).
    """
    host_names = ", ".join(h.ssh_hostname for h in hosts)
    log(args.verbose, f"Preparing an ethernet hostfile for {host_names}")

    # Check that we can ssh
    check_ssh_connections(hosts)

    # Get the ips for each host
    for h in hosts:
        log(args.verbose, "Getting the ip from", h.ssh_hostname)
        lookup = run(
            ["ssh", h.ssh_hostname, "ipconfig", "getifaddr", "en0"],
            capture_output=True,
            text=True,
        )
        h.ips.append(lookup.stdout.strip())

    hostfile = [{"ssh": h.ssh_hostname, "ips": h.ips} for h in hosts]

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
        return

    with open(args.output_hostfile, "w") as f:
        json.dump(hostfile, f, indent=4)
|
||||
|
||||
|
||||
def distributed_config():
    """Command line entry point that configures hosts for MLX distributed.

    Parses the command line, resolves the hosts from ``--hostfile`` (when
    given) or from the ``--hosts`` list, and then prepares either a
    thunderbolt ring or an ethernet hostfile depending on ``--over``.

    Raises:
        ValueError: When ``--backend mpi`` is combined with
            ``--over thunderbolt``.
    """
    parser = argparse.ArgumentParser(
        description="Configure remote machines for use with MLX distributed"
    )

    # (flags, keyword arguments) for every option, in help order.
    option_table = [
        (
            ("--verbose",),
            dict(action="store_true", help="Print debug messages in stdout"),
        ),
        (
            ("--backend",),
            dict(
                choices=["ring", "mpi"],
                default="ring",
                help="Which distributed backend to configure",
            ),
        ),
        (
            ("--over",),
            dict(
                choices=["thunderbolt", "ethernet"],
                default="thunderbolt",
                help="What type of connectivity to configure",
            ),
        ),
        (
            ("--hosts",),
            dict(default="127.0.0.1", help="A comma separated list of hosts"),
        ),
        (("--hostfile",), dict(help="The file containing the hosts")),
        (
            ("--dot",),
            dict(
                action="store_true",
                help="Output the topology in DOT format and exit",
            ),
        ),
        (
            ("--hostfile-only",),
            dict(action="store_true", help="If set only compute the hostfile"),
        ),
        (
            ("--output-hostfile",),
            dict(help="If provided, save the hostfile to this path"),
        ),
        (
            ("--auto-setup",),
            dict(
                action="store_true",
                help="If set we will attempt to automatically configure the machines via ssh",
            ),
        ),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    if (args.backend, args.over) == ("mpi", "thunderbolt"):
        raise ValueError(
            "The configuration of MPI over thunderbolt is "
            "not supported yet by mlx.distributed_config"
        )

    if args.hostfile is None:
        hosts = parse_hostlist(parser, args.hosts, 1)
    else:
        hosts = parse_hostfile(parser, args.hostfile)

    if args.over == "thunderbolt":
        prepare_tb_ring(args, hosts)
    else:
        prepare_hostfile(args, hosts)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Launch an MLX distributed program")
|
||||
parser.add_argument(
|
||||
|
7
setup.py
7
setup.py
@ -194,7 +194,12 @@ if __name__ == "__main__":
|
||||
"typing_extensions",
|
||||
],
|
||||
},
|
||||
entry_points={"console_scripts": ["mlx.launch = mlx.distributed_run:main"]},
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
"mlx.launch = mlx.distributed_run:main",
|
||||
"mlx.distributed_config = mlx.distributed_run:distributed_config",
|
||||
]
|
||||
},
|
||||
ext_modules=[CMakeExtension("mlx.core")],
|
||||
cmdclass={"build_ext": CMakeBuild, "generate_stubs": GenerateStubs},
|
||||
zip_safe=False,
|
||||
|
Loading…
Reference in New Issue
Block a user