Add more checks and improve errors

This commit is contained in:
Angelos Katharopoulos
2025-12-09 13:36:17 -08:00
parent 9d707ba3b5
commit fa31a4b295
2 changed files with 59 additions and 16 deletions

View File

@@ -35,6 +35,7 @@ def positive_number(x):
def log(verbose, *args, **kwargs):
    """Print an informational message to stderr when ``verbose`` is truthy.

    The positional arguments are forwarded to ``print`` wrapped in the ANSI
    green colour escape and an ``[INFO]`` tag. Any ``file`` keyword supplied
    by the caller is replaced with ``sys.stderr``; all other keyword
    arguments are passed to ``print`` unchanged.
    """
    if verbose:
        # Always write to stderr, overriding any caller-supplied ``file``.
        kwargs["file"] = sys.stderr
        print("\033[32m[INFO]", *args, "\033[0m", **kwargs)

View File

@@ -57,6 +57,45 @@ def add_ethernet_ips(hosts, verbose=False):
) )
def check_rdma(hosts, verbose=False):
    """Warn about hosts that do not list any ``rdma_`` device.

    For every host, ``ibv_devices`` is invoked through ``ssh`` and its
    whitespace-separated output is searched for tokens starting with
    ``rdma_``. Each host without such a token is reported via
    ``log_warning``; if at least one host was reported, a closing notice
    pointing at the setup instructions is emitted as well.

    Args:
        hosts: Iterable of objects exposing ``ssh_hostname``.
        verbose: Forwarded to ``log`` to control progress messages.
    """
    # Check whether the hosts are capable of RDMA over thunderbolt
    missing_rdma = False
    for host in hosts:
        log(verbose, "Checking that", host.ssh_hostname, "supports RDMA")
        result = run(
            ["ssh", host.ssh_hostname, "ibv_devices"],
            capture_output=True,
            text=True,
        )
        tokens = result.stdout.strip().split()
        if any(token.startswith("rdma_") for token in tokens):
            continue
        log_warning(host.ssh_hostname, "does not seem to have RDMA enabled")
        missing_rdma = True

    if not missing_rdma:
        return

    # Summarise once, after every host has been checked.
    log_warning()
    log_warning(
        "Some of the hosts don't have RDMA enabled or they don't support RDMA."
    )
    log_warning()
    log_warning(
        "See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
    )
    log_warning("for instructions on how to enable RDMA.")
def can_auto_setup(hosts, sshinfo, auto_setup=False):
    """Return whether every entry of ``sshinfo`` reports ``has_sudo``.

    When automatic setup was requested but some entry lacks ``has_sudo``,
    the affected hostnames are listed through ``log_warning`` before
    returning.

    Args:
        hosts: Sequence of objects exposing ``ssh_hostname``; paired
            positionally with ``sshinfo`` when listing offenders.
        sshinfo: Sequence of objects exposing ``has_sudo``.
        auto_setup: Whether the caller asked for automatic setup.

    Returns:
        ``True`` if all entries have ``has_sudo`` set, ``False`` otherwise.
    """
    everyone_has_sudo = not any(not info.has_sudo for info in sshinfo)

    # Nothing to report unless auto setup was requested and cannot proceed.
    if everyone_has_sudo or not auto_setup:
        return everyone_has_sudo

    log_warning(
        "Automatic setup requested but the following hosts do not have passwordless sudo"
    )
    for host, info in zip(hosts, sshinfo):
        if info.has_sudo:
            continue
        log_warning(" - ", host.ssh_hostname)
    return everyone_has_sudo
class IPConfigurator: class IPConfigurator:
def __init__(self, hosts, tb_hosts, uuid_reverse_index): def __init__(self, hosts, tb_hosts, uuid_reverse_index):
assigned = set() assigned = set()
@@ -278,6 +317,8 @@ def check_valid_mesh(hosts, connectivity, strict=True):
log_error( log_error(
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}" f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
) )
log_error()
log_error("Try passing --dot to visualize the connectivity")
sys.exit(1) sys.exit(1)
else: else:
return False return False
@@ -365,7 +406,7 @@ def prepare_ethernet_hostfile(args, hosts):
print(json.dumps(hostfile, indent=4)) print(json.dumps(hostfile, indent=4))
def configure_ring(args, hosts, ips, ring): def configure_ring(args, hosts, ips, ring, sshinfo):
log(args.verbose, "Prepare a ring hostfile") log(args.verbose, "Prepare a ring hostfile")
ring, count = ring ring, count = ring
hostfile = [] hostfile = []
@@ -380,7 +421,8 @@ def configure_ring(args, hosts, ips, ring):
} }
) )
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup) has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile: if args.output_hostfile:
with open(args.output_hostfile, "w") as f: with open(args.output_hostfile, "w") as f:
@@ -391,9 +433,10 @@ def configure_ring(args, hosts, ips, ring):
print(json.dumps(hostfile, indent=4)) print(json.dumps(hostfile, indent=4))
def configure_jaccl(args, hosts, ips): def configure_jaccl(args, hosts, ips, sshinfo):
log(args.verbose, "Prepare a jaccl hostfile") log(args.verbose, "Prepare a jaccl hostfile")
add_ethernet_ips(hosts) check_rdma(hosts, args.verbose)
add_ethernet_ips(hosts, args.verbose)
hostfile = [] hostfile = []
for i, h in enumerate(hosts): for i, h in enumerate(hosts):
@@ -405,7 +448,8 @@ def configure_jaccl(args, hosts, ips):
rdma.append(f"rdma_{ips.ips[i, j][0][0]}") rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma}) hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup) has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
if args.output_hostfile: if args.output_hostfile:
with open(args.output_hostfile, "w") as f: with open(args.output_hostfile, "w") as f:
@@ -416,7 +460,7 @@ def configure_jaccl(args, hosts, ips):
print(json.dumps(hostfile, indent=4)) print(json.dumps(hostfile, indent=4))
def prepare_tb_hostfile(args, hosts): def prepare_tb_hostfile(args, hosts, sshinfo):
log(args.verbose, f"Preparing for communication over thunderbolt") log(args.verbose, f"Preparing for communication over thunderbolt")
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose) tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
@@ -438,26 +482,28 @@ def prepare_tb_hostfile(args, hosts):
sys.exit(1) sys.exit(1)
elif has_ring: elif has_ring:
configure_ring(args, hosts, ips, rings[0]) configure_ring(args, hosts, ips, rings[0], sshinfo)
else: else:
configure_jaccl(args, hosts, ips) configure_jaccl(args, hosts, ips, sshinfo)
elif args.backend == "ring": elif args.backend == "ring":
rings = extract_rings(connectivity) rings = extract_rings(connectivity)
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts) has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
if not has_ring: if not has_ring:
log_error("Could not find a full ring.") log_error("Could not find a full ring.")
log_error()
log_error("Try passing --dot to visualize the connectivity")
if len(rings) > 0: if len(rings) > 0:
log_error("Rings found:") log_error("Rings found:")
for r in rings: for r in rings:
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}") log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
sys.exit(1) sys.exit(1)
configure_ring(args, hosts, ips, rings[0]) configure_ring(args, hosts, ips, rings[0], sshinfo)
elif args.backend == "jaccl": elif args.backend == "jaccl":
check_valid_mesh(hosts, connectivity) check_valid_mesh(hosts, connectivity)
configure_jaccl(args, hosts, ips) configure_jaccl(args, hosts, ips, sshinfo)
def main(): def main():
@@ -498,10 +544,6 @@ def main():
default=None, default=None,
help="Which distributed backend to configure", help="Which distributed backend to configure",
) )
# parser.add_argument(
# "--hostfile-only", action="store_true", help="If set only compute the hostfile"
# )
args = parser.parse_args() args = parser.parse_args()
if args.hostfile is not None: if args.hostfile is not None:
@@ -514,7 +556,7 @@ def main():
args.verbose, args.verbose,
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}", f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
) )
check_ssh_connections(hosts) sshinfo = check_ssh_connections(hosts)
# Prepare a hostfile for communication over ethernet using the ips of the # Prepare a hostfile for communication over ethernet using the ips of the
# provided hostnames # provided hostnames
@@ -523,4 +565,4 @@ def main():
# Configure the macs for communication over thunderbolt, both via RDMA and IP # Configure the macs for communication over thunderbolt, both via RDMA and IP
else: else:
prepare_tb_hostfile(args, hosts) prepare_tb_hostfile(args, hosts, sshinfo)