Add more checks and improve errors

This commit is contained in:
Angelos Katharopoulos
2025-12-09 13:36:17 -08:00
parent 9d707ba3b5
commit fa31a4b295
2 changed files with 59 additions and 16 deletions

View File

@@ -35,6 +35,7 @@ def positive_number(x):
def log(verbose, *args, **kwargs):
    """Print an informational message to stderr when ``verbose`` is truthy.

    The positional arguments are forwarded to ``print`` between a green
    ``[INFO]`` ANSI prefix and an ANSI reset suffix. Any extra keyword
    arguments are forwarded as well, except that ``file`` is always forced
    to ``sys.stderr``.
    """
    if verbose:
        # A caller-supplied ``file`` is deliberately overridden so that
        # informational output never ends up on stdout.
        print(
            "\033[32m[INFO]",
            *args,
            "\033[0m",
            **{**kwargs, "file": sys.stderr},
        )

View File

@@ -57,6 +57,45 @@ def add_ethernet_ips(hosts, verbose=False):
)
def check_rdma(hosts, verbose=False):
    """Warn about hosts that do not appear to be capable of RDMA over thunderbolt.

    For every host, ``ibv_devices`` is run over ssh and its output is scanned
    for whitespace-separated tokens starting with ``rdma_``. A host with no
    such token gets an individual warning; if at least one host was flagged, a
    summary pointing at the setup instructions is emitted afterwards.

    Args:
        hosts: Iterable of host objects exposing ``ssh_hostname``.
        verbose: Forwarded to ``log`` to control progress messages.
    """
    any_missing = False
    for host in hosts:
        log(verbose, "Checking that", host.ssh_hostname, "supports RDMA")
        result = run(
            ["ssh", host.ssh_hostname, "ibv_devices"],
            capture_output=True,
            text=True,
        )
        tokens = result.stdout.strip().split()
        if any(token.startswith("rdma_") for token in tokens):
            continue
        log_warning(host.ssh_hostname, "does not seem to have RDMA enabled")
        any_missing = True

    if not any_missing:
        return

    # Summary shown once, after all hosts have been probed.
    log_warning()
    log_warning(
        "Some of the hosts don't have RDMA enabled or they don't support RDMA."
    )
    log_warning()
    log_warning(
        "See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
    )
    log_warning("for instructions on how to enable RDMA.")
def can_auto_setup(hosts, sshinfo, auto_setup=False):
    """Report whether every host has sudo according to ``sshinfo``.

    When ``auto_setup`` is requested but some entry in ``sshinfo`` has a falsy
    ``has_sudo``, a warning listing the affected hostnames is emitted.

    Args:
        hosts: Host objects exposing ``ssh_hostname``; paired with ``sshinfo``
            by position.
        sshinfo: Per-host records exposing ``has_sudo``.
        auto_setup: Whether automatic setup was requested by the user.

    Returns:
        ``True`` if all entries of ``sshinfo`` report ``has_sudo``, else
        ``False``.
    """
    everyone_has_sudo = all(entry.has_sudo for entry in sshinfo)

    # Nothing to report unless automatic setup was asked for and is impossible.
    if everyone_has_sudo or not auto_setup:
        return everyone_has_sudo

    log_warning(
        "Automatic setup requested but the following hosts do not have passwordless sudo"
    )
    for host, entry in zip(hosts, sshinfo):
        if entry.has_sudo:
            continue
        log_warning(" - ", host.ssh_hostname)
    return everyone_has_sudo
class IPConfigurator:
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
assigned = set()
@@ -278,6 +317,8 @@ def check_valid_mesh(hosts, connectivity, strict=True):
log_error(
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
)
log_error()
log_error("Try passing --dot to visualize the connectivity")
sys.exit(1)
else:
return False
@@ -365,7 +406,7 @@ def prepare_ethernet_hostfile(args, hosts):
print(json.dumps(hostfile, indent=4))
def configure_ring(args, hosts, ips, ring):
def configure_ring(args, hosts, ips, ring, sshinfo):
log(args.verbose, "Prepare a ring hostfile")
ring, count = ring
hostfile = []
@@ -380,7 +421,8 @@ def configure_ring(args, hosts, ips, ring):
}
)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
@@ -391,9 +433,10 @@ def configure_ring(args, hosts, ips, ring):
print(json.dumps(hostfile, indent=4))
def configure_jaccl(args, hosts, ips):
def configure_jaccl(args, hosts, ips, sshinfo):
log(args.verbose, "Prepare a jaccl hostfile")
add_ethernet_ips(hosts)
check_rdma(hosts, args.verbose)
add_ethernet_ips(hosts, args.verbose)
hostfile = []
for i, h in enumerate(hosts):
@@ -405,7 +448,8 @@ def configure_jaccl(args, hosts, ips):
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile:
with open(args.output_hostfile, "w") as f:
@@ -416,7 +460,7 @@ def configure_jaccl(args, hosts, ips):
print(json.dumps(hostfile, indent=4))
def prepare_tb_hostfile(args, hosts):
def prepare_tb_hostfile(args, hosts, sshinfo):
log(args.verbose, f"Preparing for communication over thunderbolt")
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
@@ -438,26 +482,28 @@ def prepare_tb_hostfile(args, hosts):
sys.exit(1)
elif has_ring:
configure_ring(args, hosts, ips, rings[0])
configure_ring(args, hosts, ips, rings[0], sshinfo)
else:
configure_jaccl(args, hosts, ips)
configure_jaccl(args, hosts, ips, sshinfo)
elif args.backend == "ring":
rings = extract_rings(connectivity)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring:
log_error("Could not find a full ring.")
log_error()
log_error("Try passing --dot to visualize the connectivity")
if len(rings) > 0:
log_error("Rings found:")
for r in rings:
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
sys.exit(1)
configure_ring(args, hosts, ips, rings[0])
configure_ring(args, hosts, ips, rings[0], sshinfo)
elif args.backend == "jaccl":
check_valid_mesh(hosts, connectivity)
configure_jaccl(args, hosts, ips)
configure_jaccl(args, hosts, ips, sshinfo)
def main():
@@ -498,10 +544,6 @@ def main():
default=None,
help="Which distributed backend to configure",
)
# parser.add_argument(
# "--hostfile-only", action="store_true", help="If set only compute the hostfile"
# )
args = parser.parse_args()
if args.hostfile is not None:
@@ -514,7 +556,7 @@ def main():
args.verbose,
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
)
check_ssh_connections(hosts)
sshinfo = check_ssh_connections(hosts)
# Prepare a hostfile for communication over ethernet using the ips of the
# provided hostnames
@@ -523,4 +565,4 @@ def main():
# Configure the macs for communication over thunderbolt, both via RDMA and IP
else:
prepare_tb_hostfile(args, hosts)
prepare_tb_hostfile(args, hosts, sshinfo)