mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Refactor distributed config
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2025 Apple Inc.
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
@@ -16,6 +17,14 @@ class Host:
|
|||||||
rdma: list[Optional[str]]
|
rdma: list[Optional[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class OptionalBoolAction(argparse.Action):
    """Argparse action backing a paired ``--flag`` / ``--no-flag`` switch.

    The destination attribute is set to ``False`` when the option string used
    on the command line starts with ``--no-`` and to ``True`` otherwise.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # The spelling of the flag that was typed decides the stored value.
        enabled = not option_string.startswith("--no-")
        setattr(namespace, self.dest, enabled)
|
||||||
|
|
||||||
|
|
||||||
def positive_number(x):
|
def positive_number(x):
|
||||||
x = int(x)
|
x = int(x)
|
||||||
if x <= 0:
|
if x <= 0:
|
||||||
@@ -50,7 +59,7 @@ def parse_hostlist(parser, hostlist, repeats):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
ips = []
|
ips = []
|
||||||
for i in range(repeats):
|
for i in range(repeats):
|
||||||
hosts.append(Host(i, h, ips))
|
hosts.append(Host(i, h, ips, []))
|
||||||
return hosts
|
return hosts
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,526 @@
|
|||||||
|
# Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import shlex
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from subprocess import DEVNULL, run
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
from .common import (
|
||||||
|
Host,
|
||||||
|
OptionalBoolAction,
|
||||||
|
log,
|
||||||
|
log_error,
|
||||||
|
parse_hostfile,
|
||||||
|
parse_hostlist,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SSHInfo:
    """Outcome of probing one host over ssh."""

    # True when a remote command could be run over ssh.
    can_ssh: bool
    # True when `sudo` could be run over ssh.
    has_sudo: bool

    def __bool__(self):
        """Truthiness follows ``can_ssh`` only; ``has_sudo`` is not considered."""
        return self.can_ssh
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ThunderboltPort:
    """One thunderbolt port of a host."""

    # Network interface name associated with the port.
    iface: str
    # UUID identifying this port.
    uuid: str
    # UUID of the port on the other end of the cable, or None if unconnected.
    connected_to: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ThunderboltHost:
    """A host together with the thunderbolt ports discovered on it."""

    # Device name reported for the host.
    name: str
    # The host's thunderbolt ports.
    ports: list[ThunderboltPort]
|
||||||
|
|
||||||
|
|
||||||
|
def add_ethernet_ips(hosts, verbose=False):
    """Append the ``en0`` address of every host to its ``ips`` list.

    Runs ``ipconfig getifaddr en0`` on each host over ssh and appends the
    stripped stdout to ``host.ips`` in place. The exit status of ssh is not
    inspected, so whatever was written to stdout is appended as is.
    """
    for host in hosts:
        log(verbose, "Getting the ip from", host.ssh_hostname)
        query = ["ssh", host.ssh_hostname, "ipconfig", "getifaddr", "en0"]
        completed = run(query, capture_output=True, text=True)
        host.ips.append(completed.stdout.strip())
|
||||||
|
|
||||||
|
|
||||||
|
class IPConfigurator:
    """Assign point-to-point IPs to thunderbolt links and apply them.

    ``self.ips`` maps ``(node, peer_node)`` to a list of ``(iface, ip)`` pairs,
    one per cable between the two nodes, holding the local interface and the
    local address of ``node`` on that link.
    """

    def __init__(self, hosts, tb_hosts, uuid_reverse_index):
        """Walk every connected port once and give each link its own addresses.

        Each link consumes a block of four consecutive addresses in
        ``192.168.x.y``: offset 1 goes to the port the link was discovered
        from and offset 2 to the port on the other end.

        Raises:
            ValueError: when the third octet would exceed 255.
        """
        seen_ports = set()
        link_ips = defaultdict(list)
        subnet, base = 0, 0

        for src_node, tb_host in enumerate(tb_hosts):
            for src_port, port in enumerate(tb_host.ports):
                # Skip unconnected ports and links already handled from the
                # other end.
                if not port.connected_to or (src_node, src_port) in seen_ports:
                    continue

                dst_node, dst_port = uuid_reverse_index[port.connected_to]
                peer_port = tb_hosts[dst_node].ports[dst_port]

                src_entry = (port.iface, f"192.168.{subnet}.{base + 1}")
                dst_entry = (peer_port.iface, f"192.168.{subnet}.{base + 2}")
                link_ips[src_node, dst_node].append(src_entry)
                link_ips[dst_node, src_node].append(dst_entry)

                seen_ports.add((src_node, src_port))
                seen_ports.add((dst_node, dst_port))

                # Move on to the next block of four addresses.
                base += 4
                if base > 255:
                    subnet += 1
                    base = 0
                if subnet > 255:
                    raise ValueError("Ran out of available local IPs")

        self.ips = link_ips
        self.hosts = hosts
        self.tb_hosts = tb_hosts

    def setup(self, verbose=False, auto_setup=False):
        """Emit, and optionally run, the network setup commands per host.

        With ``auto_setup`` truthy the commands are joined with `` ; `` and
        executed on the host through ssh; otherwise they are printed and the
        user is asked to press Enter before continuing to the next host.
        """
        netmask = "255.255.255.252"
        num_hosts = len(self.hosts)
        for i, (host, _tb_host) in enumerate(zip(self.hosts, self.tb_hosts)):
            lines = ["sudo ifconfig bridge0 down"]
            for j in range(num_hosts):
                if i == j or (i, j) not in self.ips:
                    continue
                # Pair the local side of each link with the peer's address.
                for (iface, ip), (_, peer) in zip(self.ips[i, j], self.ips[j, i]):
                    lines.append(f"sudo ifconfig {iface} inet {ip} netmask {netmask}")
                    lines.append(f"sudo route change {peer} -interface {iface}")
            command = "".join(f"{line}\n" for line in lines)

            if auto_setup:
                print(f"Running auto setup for {host.ssh_hostname}")
                one_liner = command.strip().replace("\n", " ; ")
                ssh_command = ["ssh", host.ssh_hostname, one_liner]
                log(verbose, shlex.join(ssh_command))
                run(ssh_command)
            else:
                msg = f"Setup for {host.ssh_hostname}"
                print(msg)
                print("=" * len(msg))
                print(command)
                input("Enter to continue")
            print()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_hardware_ports(ports_string):
    """Parse ``networksetup -listallhardwareports`` output into a mapping.

    The output is scanned for ``Hardware Port: <name>`` lines, each followed
    by a ``Device: <device>`` line; every such pair becomes one entry.

    Args:
        ports_string: The command output, as UTF-8 ``bytes`` or as ``str``.

    Returns:
        A dict mapping hardware port name to device name.
    """
    if isinstance(ports_string, (bytes, bytearray)):
        ports_string = ports_string.decode("utf-8")

    port_prefix = "Hardware Port:"
    device_prefix = "Device:"

    ports = {}
    port_name = None
    for line in ports_string.split("\n"):
        if line.startswith(port_prefix):
            port_name = line[len(port_prefix):].strip()
        elif line.startswith(device_prefix):
            # A device line without a preceding port line has no name to be
            # filed under; skip it instead of storing it under a None key.
            if port_name is not None:
                ports[port_name] = line[len(device_prefix):].strip()
            port_name = None
    return ports
|
||||||
|
|
||||||
|
|
||||||
|
def extract_connectivity(hosts, verbose):
    """Query every host over ssh and describe its thunderbolt ports.

    For each host this runs ``system_profiler SPThunderboltDataType -json``
    and ``networksetup -listallhardwareports`` and combines the two outputs.

    Returns:
        A tuple ``(tb_hosts, uuid_reverse_index)`` where ``tb_hosts`` holds
        one ``ThunderboltHost`` per host (ports sorted by interface name) and
        ``uuid_reverse_index`` maps a port uuid to ``(host_index, port_index)``.
    """

    def ssh_stdout(hostname, *remote_command):
        # Raw bytes written to stdout by the remote command.
        return run(["ssh", hostname, *remote_command], capture_output=True).stdout

    # Extract the current connectivity from the remote hosts
    profiles = []
    for host in hosts:
        log(verbose, "Getting connectivity from", host.ssh_hostname)
        raw = ssh_stdout(
            host.ssh_hostname, "system_profiler", "SPThunderboltDataType", "-json"
        )
        profiles.append(json.loads(raw))

    iface_maps = []
    for host in hosts:
        log(verbose, "Getting interface names from", host.ssh_hostname)
        raw = ssh_stdout(host.ssh_hostname, "networksetup", "-listallhardwareports")
        iface_maps.append(parse_hardware_ports(raw))

    # Parse the connectivity into some simple dataclasses
    tb_hosts = []
    for profile, iface_map in zip(profiles, iface_maps):
        name = ""
        ports = []
        for entry in profile["SPThunderboltDataType"]:
            uuid = entry.get("domain_uuid_key")
            if uuid is None:
                continue
            name = entry["device_name_key"]
            tag = entry["receptacle_1_tag"]["receptacle_id_key"]
            # The first attached item carrying a uuid is the peer port.
            connected_to = next(
                (
                    item["domain_uuid_key"]
                    for item in entry.get("_items", [])
                    if "domain_uuid_key" in item
                ),
                None,
            )
            iface = iface_map[f"Thunderbolt {tag}"]
            ports.append(ThunderboltPort(iface, uuid, connected_to))
        ports.sort(key=lambda port: port.iface)
        tb_hosts.append(ThunderboltHost(name, ports))

    # Create a reverse index to be able to map uuids to (host, port) quickly
    uuid_reverse_index = {
        port.uuid: (i, j)
        for i, tb_host in enumerate(tb_hosts)
        for j, port in enumerate(tb_host.ports)
    }

    return tb_hosts, uuid_reverse_index
|
||||||
|
|
||||||
|
|
||||||
|
def make_connectivity_matrix(tb_hosts, uuid_reverse_index):
    """Count the thunderbolt links between every pair of hosts.

    Returns:
        A square list of lists where entry ``[i][j]`` is the number of ports
        of host ``i`` whose peer port belongs to host ``j``.
    """
    num_hosts = len(tb_hosts)
    matrix = []
    for tb_host in tb_hosts:
        row = [0] * num_hosts
        for port in tb_host.ports:
            if port.connected_to is None:
                continue
            peer_node, _ = uuid_reverse_index[port.connected_to]
            row[peer_node] += 1
        matrix.append(row)
    return matrix
|
||||||
|
|
||||||
|
|
||||||
|
def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
    """Print the thunderbolt topology to stdout as an undirected DOT graph.

    Every host becomes a rectangular node labelled with its ssh hostname and
    every link an edge labelled ``<local iface>/<peer iface>``. A link is
    skipped on the host whose index is higher than its peer's, so a link
    between two different hosts is printed only once.
    """

    def node_id(index):
        # Letters a-z encode the index in base 26, least significant first.
        letters = [chr(97 + index % 26)]
        index //= 26
        while index != 0:
            letters.append(chr(97 + index % 26))
            index //= 26
        return "".join(letters)

    # Make ids per node
    names = [node_id(i) for i in range(len(tb_hosts))]

    print("graph G {")
    print(" node [shape=rectangle];")
    for i, host in enumerate(hosts):
        print(f' {names[i]} [label="{host.ssh_hostname}"];')
    for i, tb_host in enumerate(tb_hosts):
        for port in tb_host.ports:
            if not port.connected_to:
                continue
            dst_node, dst_port = uuid_reverse_index[port.connected_to]
            if dst_node < i:
                continue
            print(f" {names[i]} -- {names[dst_node]}", end="")
            print(f' [label="{port.iface}/{tb_hosts[dst_node].ports[dst_port].iface}"]')
    print("}")
|
||||||
|
|
||||||
|
|
||||||
|
def extract_rings(connectivity):
    """Enumerate the rings (cycles) present in a connectivity matrix.

    Args:
        connectivity: Square matrix where ``[i][j]`` is the number of links
            from node ``i`` to node ``j``.

    Returns:
        A list of ``(nodes, count)`` tuples sorted by decreasing ring length.
        ``nodes`` lists the ring's nodes in traversal order and ``count`` is
        the smallest number of links on any edge of the ring. Only the first
        ring found for a given set of nodes is kept.
    """
    num_nodes = len(connectivity)

    def walk(origin, current, trail, on_trail):
        # Depth-first search yielding a copy of the trail whenever an edge
        # leads back to the origin.
        trail.append(current)
        on_trail.add(current)
        for candidate in range(num_nodes):
            if connectivity[current][candidate] <= 0:
                continue
            if candidate == origin:
                yield trail[:]
            if candidate not in on_trail:
                yield from walk(origin, candidate, trail, on_trail)
        trail.pop()
        on_trail.remove(current)

    found = []
    seen_node_sets = set()
    for origin in range(num_nodes):
        for ring in walk(origin, origin, [], set()):
            # Each node paired with its successor, wrapping around at the end.
            edges = zip(ring, ring[1:] + ring[:1])
            count = min(connectivity[a][b] for a, b in edges)
            key = tuple(sorted(ring))
            if key in seen_node_sets:
                continue
            found.append((ring, count))
            seen_node_sets.add(key)

    # Stable sort: rings of equal length keep their discovery order.
    found.sort(key=lambda entry: len(entry[0]), reverse=True)
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def check_valid_mesh(hosts, connectivity, strict=True):
    """Check that every node has at least one link to every other node.

    Args:
        hosts: Hosts whose ``ssh_hostname`` is used in the error message.
        connectivity: Square matrix of link counts between nodes.
        strict: When true, a missing link is reported through ``log_error``
            and the process exits with status 1; otherwise ``False`` is
            returned.

    Returns:
        ``True`` for a complete mesh, ``False`` for an incomplete one when
        ``strict`` is false.
    """
    num_nodes = len(connectivity)
    ordered_pairs = (
        (i, j) for i in range(num_nodes) for j in range(num_nodes) if i != j
    )
    for i, j in ordered_pairs:
        if connectivity[i][j] > 0:
            continue
        if not strict:
            return False
        log_error(
            f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
        )
        sys.exit(1)
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def check_ssh_connections(hosts):
    """Probe all hosts in parallel for ssh access and for sudo over ssh.

    One thread per host runs ``echo success`` over ssh (batch mode, five
    second connect timeout) and, if that succeeds, ``sudo ls``. If any host
    cannot be reached the unreachable hosts are reported through
    ``log_error`` and the process exits with status 1.

    Returns:
        A list with one ``SSHInfo`` per host, in the order of ``hosts``.
    """
    results = [None] * len(hosts)

    def _succeeds(hostname, *remote_command):
        # True when the remote command exits with status 0; output is dropped.
        completed = run(
            [
                "ssh",
                "-o",
                "BatchMode=yes",
                "-o",
                "ConnectTimeout=5",
                hostname,
                *remote_command,
            ],
            stdout=DEVNULL,
            stderr=DEVNULL,
        )
        return completed.returncode == 0

    def _check(hostname, i):
        info = SSHInfo(False, False)
        results[i] = info

        # Check for ssh
        info.can_ssh = _succeeds(hostname, "echo", "success")
        if info.can_ssh:
            # Check for sudo
            info.has_sudo = _succeeds(hostname, "sudo", "ls")

    threads = []
    for i, host in enumerate(hosts):
        threads.append(threading.Thread(target=_check, args=(host.ssh_hostname, i)))
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    unreachable = [host for host, info in zip(hosts, results) if not info]
    if unreachable:
        log_error("Could not ssh to the following hosts:")
        for host in unreachable:
            log_error(" - ", host.ssh_hostname)
        log_error()
        log_error("Maybe they are not set-up for password-less ssh?")
        sys.exit(1)

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_ethernet_hostfile(args, hosts):
    """Build a hostfile for communication over ethernet.

    The ``en0`` address of every host is collected first. The resulting list
    of ``{"ssh", "ips"}`` entries is written as JSON to
    ``args.output_hostfile`` when that is set and printed to stdout otherwise.
    """
    log(args.verbose, "Preparing an ethernet hostfile")
    add_ethernet_ips(hosts, args.verbose)

    hostfile = [dict(ssh=host.ssh_hostname, ips=host.ips) for host in hosts]

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
        return

    with open(args.output_hostfile, "w") as f:
        json.dump(hostfile, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_ring(args, hosts, ips, ring):
    """Configure the hosts for the ring backend and emit the hostfile.

    Args:
        args: Parsed CLI arguments (``verbose``, ``auto_setup`` and
            ``output_hostfile`` are read).
        hosts: The hosts, indexable by node number.
        ips: The ``IPConfigurator`` holding the per-link addresses.
        ring: A ``(nodes, count)`` tuple: the nodes in ring order and the
            number of links to use between neighbours.
    """
    log(args.verbose, "Prepare a ring hostfile")
    nodes, count = ring

    hostfile = []
    for position, node in enumerate(nodes):
        host = hosts[node]
        # Index -1 wraps around, so the first node is paired with the last.
        previous = nodes[position - 1]
        entry = {
            "ssh": host.ssh_hostname,
            "ips": [ips.ips[node, previous][c][1] for c in range(count)],
            "rdma": [],
        }
        hostfile.append(entry)

    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)

    if not args.output_hostfile:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
        return

    with open(args.output_hostfile, "w") as f:
        json.dump(hostfile, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_jaccl(args, hosts, ips):
    """Configure the hosts for the jaccl backend and emit the hostfile.

    Every hostfile entry carries the host's ethernet ips plus an ``rdma``
    list with one item per host: ``None`` for the host itself and
    ``rdma_<iface>`` otherwise, where ``<iface>`` is the interface of the
    first link to that peer.

    Args:
        args: Parsed CLI arguments (``verbose``, ``auto_setup`` and
            ``output_hostfile`` are read).
        hosts: The hosts to configure.
        ips: The ``IPConfigurator`` holding the per-link interfaces.
    """
    log(args.verbose, "Prepare a jaccl hostfile")
    # Forward the verbosity flag so --verbose also covers the ip lookup, as it
    # does when preparing an ethernet hostfile.
    add_ethernet_ips(hosts, args.verbose)

    num_hosts = len(hosts)
    hostfile = []
    for i, h in enumerate(hosts):
        rdma = [
            None if i == j else f"rdma_{ips.ips[i, j][0][0]}"
            for j in range(num_hosts)
        ]
        hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})

    ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)

    if args.output_hostfile:
        with open(args.output_hostfile, "w") as f:
            json.dump(hostfile, f, indent=4)
    else:
        print("Hostfile")
        print("========")
        print(json.dumps(hostfile, indent=4))
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_tb_hostfile(args, hosts):
    """Configure the hosts for communication over thunderbolt.

    With ``args.dot`` set, only the topology is printed in DOT format.
    Otherwise the backend is taken from ``args.backend``; when it is ``None``
    a ring spanning all hosts is preferred and a full mesh (jaccl) is the
    fallback. The process exits with status 1 when the required topology is
    not available.

    Args:
        args: Parsed CLI arguments (``verbose``, ``dot`` and ``backend`` are
            read here; the configure functions read further attributes).
        hosts: The hosts to configure.
    """
    log(args.verbose, "Preparing for communication over thunderbolt")
    tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)

    if args.dot:
        tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index)
        return

    ips = IPConfigurator(hosts, tb_hosts, uuid_reverse_index)
    connectivity = make_connectivity_matrix(tb_hosts, uuid_reverse_index)

    if args.backend is None:
        rings = extract_rings(connectivity)
        has_mesh = check_valid_mesh(hosts, connectivity, False)
        # Rings are sorted longest first, so only the first one can be full.
        has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)

        if not has_ring and not has_mesh:
            log_error("Neither thunderbolt mesh nor ring found.")
            log_error("Perhaps run with --dot to generate a plot of the connectivity.")
            sys.exit(1)

        elif has_ring:
            configure_ring(args, hosts, ips, rings[0])

        else:
            configure_jaccl(args, hosts, ips)

    elif args.backend == "ring":
        rings = extract_rings(connectivity)
        has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
        if not has_ring:
            log_error("Could not find a full ring.")
            if len(rings) > 0:
                log_error("Rings found:")
                # Each entry is a (nodes, count) tuple; only the nodes are
                # host indices.
                for ring_nodes, _ in rings:
                    log_error(
                        f" - {','.join(hosts[i].ssh_hostname for i in ring_nodes)}"
                    )
            sys.exit(1)
        configure_ring(args, hosts, ips, rings[0])

    elif args.backend == "jaccl":
        check_valid_mesh(hosts, connectivity)
        configure_jaccl(args, hosts, ips)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command line entry point for configuring hosts for MLX distributed.

    Parses the arguments, resolves the hosts from ``--hostfile`` or
    ``--hosts``, verifies ssh access and then prepares either an ethernet or
    a thunderbolt configuration depending on ``--over``.
    """
    parser = argparse.ArgumentParser(
        description="Configure remote machines for use with MLX distributed"
    )
    add = parser.add_argument
    add("--verbose", action="store_true", help="Print debug messages in stdout")
    add("--hosts", default="127.0.0.1", help="A comma separated list of hosts")
    add("--hostfile", help="The file containing the hosts")
    # NOTE(review): required=True means the default below is never used —
    # confirm which of the two is intended.
    add(
        "--over",
        choices=["thunderbolt", "ethernet"],
        default="thunderbolt",
        help="What type of connectivity to configure",
        required=True,
    )
    add("--output-hostfile", help="If provided, save the hostfile to this path")
    add(
        "--auto-setup",
        "--no-auto-setup",
        action=OptionalBoolAction,
        nargs=0,
        dest="auto_setup",
        default=None,
    )
    add("--dot", action="store_true", help="Output the topology in DOT format and exit")
    add(
        "--backend",
        choices=["ring", "jaccl"],
        default=None,
        help="Which distributed backend to configure",
    )

    # parser.add_argument(
    #     "--hostfile-only", action="store_true", help="If set only compute the hostfile"
    # )
    args = parser.parse_args()

    if args.hostfile is None:
        hosts = parse_hostlist(parser, args.hosts, 1)
    else:
        hosts = parse_hostfile(parser, args.hostfile)

    # Check that we can ssh
    hostnames = ", ".join(h.ssh_hostname for h in hosts)
    log(args.verbose, f"Checking for ssh access for {hostnames}")
    check_ssh_connections(hosts)

    if args.over == "ethernet":
        # Prepare a hostfile for communication over ethernet using the ips of
        # the provided hostnames
        prepare_ethernet_hostfile(args, hosts)
    else:
        # Configure the macs for communication over thunderbolt, both via RDMA
        # and IP
        prepare_tb_hostfile(args, hosts)
|
||||||
|
|||||||
@@ -45,11 +45,13 @@ class CommandProcess:
|
|||||||
|
|
||||||
|
|
||||||
class RemoteProcess(CommandProcess):
|
class RemoteProcess(CommandProcess):
|
||||||
def __init__(self, rank, host, cwd, files, env, command):
|
def __init__(self, rank, host, python, cwd, files, env, command):
|
||||||
is_local = host == "127.0.0.1"
|
is_local = host == "127.0.0.1"
|
||||||
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
|
script = RemoteProcess.make_monitor_script(rank, cwd, files, env, command)
|
||||||
script_b64 = base64.b64encode(script.encode()).decode()
|
script_b64 = base64.b64encode(script.encode()).decode()
|
||||||
cmd = f'{sys.executable} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
cmd = (
|
||||||
|
f'{python} -c "import base64; exec(base64.b64decode(\\"{script_b64}\\"));"'
|
||||||
|
)
|
||||||
if not is_local:
|
if not is_local:
|
||||||
cmd = f"ssh {host} '{cmd}'"
|
cmd = f"ssh {host} '{cmd}'"
|
||||||
|
|
||||||
@@ -309,7 +311,7 @@ def launch_ring(parser, hosts, args, command):
|
|||||||
_launch_with_io(
|
_launch_with_io(
|
||||||
RemoteProcess,
|
RemoteProcess,
|
||||||
[
|
[
|
||||||
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
|
||||||
for rank, h in enumerate(hosts)
|
for rank, h in enumerate(hosts)
|
||||||
],
|
],
|
||||||
args.verbose,
|
args.verbose,
|
||||||
@@ -341,6 +343,7 @@ def launch_nccl(parser, hosts, args, command):
|
|||||||
(
|
(
|
||||||
rank,
|
rank,
|
||||||
h.ssh_hostname,
|
h.ssh_hostname,
|
||||||
|
args.python,
|
||||||
cwd,
|
cwd,
|
||||||
{},
|
{},
|
||||||
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
|
env + [f"CUDA_VISIBLE_DEVICES={rank % args.repeat_hosts}"],
|
||||||
@@ -374,7 +377,7 @@ def launch_jaccl(parser, hosts, args, command):
|
|||||||
_launch_with_io(
|
_launch_with_io(
|
||||||
RemoteProcess,
|
RemoteProcess,
|
||||||
[
|
[
|
||||||
((rank, h.ssh_hostname, cwd, files, env, command), {})
|
((rank, h.ssh_hostname, args.python, cwd, files, env, command), {})
|
||||||
for rank, h in enumerate(hosts)
|
for rank, h in enumerate(hosts)
|
||||||
],
|
],
|
||||||
args.verbose,
|
args.verbose,
|
||||||
@@ -503,11 +506,20 @@ def main():
|
|||||||
default=12345,
|
default=12345,
|
||||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-verify-script",
|
||||||
|
action="store_false",
|
||||||
|
dest="verify_script",
|
||||||
|
help="Do not verify that the script exists",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--python", default=sys.executable, help="Use this python on the remote hosts"
|
||||||
|
)
|
||||||
|
|
||||||
args, rest = parser.parse_known_args()
|
args, rest = parser.parse_known_args()
|
||||||
|
|
||||||
if args.print_python:
|
if args.print_python:
|
||||||
print(sys.executable)
|
print(args.python)
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(rest) == 0:
|
if len(rest) == 0:
|
||||||
@@ -523,10 +535,10 @@ def main():
|
|||||||
|
|
||||||
# Check if the script is a file and convert it to a full path
|
# Check if the script is a file and convert it to a full path
|
||||||
if (script := Path(rest[0])).exists() and script.is_file():
|
if (script := Path(rest[0])).exists() and script.is_file():
|
||||||
rest[0:1] = [sys.executable, str(script.resolve())]
|
rest[0:1] = [args.python, str(script.resolve())]
|
||||||
elif (command := shutil.which(rest[0])) is not None:
|
elif (command := shutil.which(rest[0])) is not None:
|
||||||
rest[0] = command
|
rest[0] = command
|
||||||
else:
|
elif args.verify_script:
|
||||||
raise ValueError(f"Invalid script or command {rest[0]}")
|
raise ValueError(f"Invalid script or command {rest[0]}")
|
||||||
|
|
||||||
# Launch
|
# Launch
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -266,7 +266,7 @@ if __name__ == "__main__":
|
|||||||
entry_points = {
|
entry_points = {
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
"mlx.launch = mlx._distributed_utils.launch:main",
|
"mlx.launch = mlx._distributed_utils.launch:main",
|
||||||
# "mlx.distributed_config = mlx.distributed_run:distributed_config",
|
"mlx.distributed_config = mlx._distributed_utils.config:main",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
install_requires = []
|
install_requires = []
|
||||||
|
|||||||
Reference in New Issue
Block a user