From fa31a4b2950d16d9090d350d6f201df3ee2833bb Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 9 Dec 2025 13:36:17 -0800 Subject: [PATCH] Add more checks and improve errors --- python/mlx/_distributed_utils/common.py | 1 + python/mlx/_distributed_utils/config.py | 74 +++++++++++++++++++------ 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/python/mlx/_distributed_utils/common.py b/python/mlx/_distributed_utils/common.py index a466668ff..16bf3f2be 100644 --- a/python/mlx/_distributed_utils/common.py +++ b/python/mlx/_distributed_utils/common.py @@ -35,6 +35,7 @@ def positive_number(x): def log(verbose, *args, **kwargs): if not verbose: return + kwargs["file"] = sys.stderr print("\033[32m[INFO]", *args, "\033[0m", **kwargs) diff --git a/python/mlx/_distributed_utils/config.py b/python/mlx/_distributed_utils/config.py index 07750125d..e19c3083e 100644 --- a/python/mlx/_distributed_utils/config.py +++ b/python/mlx/_distributed_utils/config.py @@ -57,6 +57,45 @@ def add_ethernet_ips(hosts, verbose=False): ) +def check_rdma(hosts, verbose=False): + # Check whether the hosts are capable of RDMA over thunderbolt + warn = False + for h in hosts: + log(verbose, "Checking that", h.ssh_hostname, "supports RDMA") + rdma_devs = ( + run(["ssh", h.ssh_hostname, "ibv_devices"], capture_output=True, text=True) + .stdout.strip() + .split() + ) + rdma_devs = [d for d in rdma_devs if d.startswith("rdma_")] + if not rdma_devs: + log_warning(h.ssh_hostname, "does not seem to have RDMA enabled") + warn = True + + if warn: + log_warning() + log_warning( + "Some of the hosts don't have RDMA enabled or they don't support RDMA." + ) + log_warning() + log_warning( + "See https://ml-explore.github.io/mlx/build/html/usage/distributed.html" + ) + log_warning("for instructions on how to enable RDMA.") + + +def can_auto_setup(hosts, sshinfo, auto_setup=False): + has_sudo = all(info.has_sudo for info in sshinfo) + if not has_sudo and auto_setup: + log_warning( + "Automatic setup requested but the following hosts do not have passwordless sudo" + ) + for h, i in zip(hosts, sshinfo): + if not i.has_sudo: + log_warning(" - ", h.ssh_hostname) + return has_sudo + + class IPConfigurator: def __init__(self, hosts, tb_hosts, uuid_reverse_index): assigned = set() @@ -278,6 +317,8 @@ def check_valid_mesh(hosts, connectivity, strict=True): log_error( f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}" ) + log_error() + log_error("Try passing --dot to visualize the connectivity") sys.exit(1) else: return False @@ -365,7 +406,7 @@ def prepare_ethernet_hostfile(args, hosts): print(json.dumps(hostfile, indent=4)) -def configure_ring(args, hosts, ips, ring): +def configure_ring(args, hosts, ips, ring, sshinfo): log(args.verbose, "Prepare a ring hostfile") ring, count = ring hostfile = [] @@ -380,7 +421,8 @@ def configure_ring(args, hosts, ips, ring): } ) - ips.setup(verbose=args.verbose, auto_setup=args.auto_setup) + has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup) + ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo) if args.output_hostfile: with open(args.output_hostfile, "w") as f: @@ -391,9 +433,10 @@ def configure_ring(args, hosts, ips, ring): print(json.dumps(hostfile, indent=4)) -def configure_jaccl(args, hosts, ips): +def configure_jaccl(args, hosts, ips, sshinfo): log(args.verbose, "Prepare a jaccl hostfile") - add_ethernet_ips(hosts) + check_rdma(hosts, args.verbose) + add_ethernet_ips(hosts, args.verbose) hostfile = [] for i, h in enumerate(hosts): @@ -405,7 +448,8 @@ def configure_jaccl(args, hosts, ips): rdma.append(f"rdma_{ips.ips[i, j][0][0]}") hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma}) - ips.setup(verbose=args.verbose, auto_setup=args.auto_setup) + has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup) + ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo) if args.output_hostfile: with open(args.output_hostfile, "w") as f: @@ -416,7 +460,7 @@ def configure_jaccl(args, hosts, ips): print(json.dumps(hostfile, indent=4)) -def prepare_tb_hostfile(args, hosts): +def prepare_tb_hostfile(args, hosts, sshinfo): log(args.verbose, f"Preparing for communication over thunderbolt") tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose) @@ -438,26 +482,28 @@ def prepare_tb_hostfile(args, hosts): sys.exit(1) elif has_ring: - configure_ring(args, hosts, ips, rings[0]) + configure_ring(args, hosts, ips, rings[0], sshinfo) else: - configure_jaccl(args, hosts, ips) + configure_jaccl(args, hosts, ips, sshinfo) elif args.backend == "ring": rings = extract_rings(connectivity) has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts) if not has_ring: log_error("Could not find a full ring.") + log_error() + log_error("Try passing --dot to visualize the connectivity") if len(rings) > 0: log_error("Rings found:") for r in rings: log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}") sys.exit(1) - configure_ring(args, hosts, ips, rings[0]) + configure_ring(args, hosts, ips, rings[0], sshinfo) elif args.backend == "jaccl": check_valid_mesh(hosts, connectivity) - configure_jaccl(args, hosts, ips) + configure_jaccl(args, hosts, ips, sshinfo) def main(): @@ -498,10 +544,6 @@ def main(): default=None, help="Which distributed backend to configure", ) - - # parser.add_argument( - # "--hostfile-only", action="store_true", help="If set only compute the hostfile" - # ) args = parser.parse_args() if args.hostfile is not None: @@ -514,7 +556,7 @@ def main(): args.verbose, f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}", ) - check_ssh_connections(hosts) + sshinfo = check_ssh_connections(hosts) # Prepare a hostfile for communication over ethernet using the ips of the # provided hostnames @@ -523,4 +565,4 @@ def main(): # Configure the macs for communication over thunderbolt, both via RDMA and IP else: - prepare_tb_hostfile(args, hosts) + prepare_tb_hostfile(args, hosts, sshinfo)