mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add more checks and improve errors
This commit is contained in:
@@ -35,6 +35,7 @@ def positive_number(x):
|
||||
def log(verbose, *args, **kwargs):
    """Print an informational message to stderr when ``verbose`` is truthy.

    The positional arguments are forwarded to ``print`` between an ANSI
    green ``[INFO]`` prefix and an ANSI reset code. Keyword arguments are
    forwarded to ``print`` as well, except that ``file`` is always forced to
    ``sys.stderr`` (any caller-supplied ``file`` is overridden).

    Nothing is printed when ``verbose`` is falsy.
    """
    if verbose:
        # The trailing "file" entry wins over any "file" the caller passed.
        print_options = {**kwargs, "file": sys.stderr}
        print("\033[32m[INFO]", *args, "\033[0m", **print_options)
|
||||
|
||||
|
||||
|
||||
@@ -57,6 +57,45 @@ def add_ethernet_ips(hosts, verbose=False):
|
||||
)
|
||||
|
||||
|
||||
def check_rdma(hosts, verbose=False):
    """Warn about hosts that do not appear to have RDMA over thunderbolt.

    For every host, ``ibv_devices`` is executed over ssh and its standard
    output is split on whitespace. A host is considered RDMA capable when at
    least one of the resulting tokens starts with ``rdma_``; output without
    such a token (including empty output) triggers a per-host warning.

    If any host triggered a warning, a final multi-line warning pointing to
    the distributed usage documentation is emitted. The function only logs;
    it returns nothing.
    """
    any_missing = False

    for host in hosts:
        log(verbose, "Checking that", host.ssh_hostname, "supports RDMA")
        result = run(
            ["ssh", host.ssh_hostname, "ibv_devices"],
            capture_output=True,
            text=True,
        )
        # split() with no separator already ignores leading/trailing whitespace
        tokens = result.stdout.split()
        if any(token.startswith("rdma_") for token in tokens):
            continue

        log_warning(host.ssh_hostname, "does not seem to have RDMA enabled")
        any_missing = True

    if not any_missing:
        return

    # Summary shown once, after all hosts have been probed
    log_warning()
    log_warning(
        "Some of the hosts don't have RDMA enabled or they don't support RDMA."
    )
    log_warning()
    log_warning(
        "See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
    )
    log_warning("for instructions on how to enable RDMA.")
|
||||
|
||||
|
||||
def can_auto_setup(hosts, sshinfo, auto_setup=False):
    """Report whether every host has sudo according to ``sshinfo``.

    ``hosts`` and ``sshinfo`` are walked in lockstep: each ``sshinfo`` entry
    exposes a ``has_sudo`` attribute and each host an ``ssh_hostname``.

    When ``auto_setup`` is truthy and at least one entry lacks sudo, a
    warning is logged followed by one line per offending host.

    Returns True only if all entries in ``sshinfo`` have ``has_sudo`` set.
    """
    everyone_has_sudo = all(entry.has_sudo for entry in sshinfo)

    # Nothing to report unless automatic setup was requested and cannot proceed
    if everyone_has_sudo or not auto_setup:
        return everyone_has_sudo

    log_warning(
        "Automatic setup requested but the following hosts do not have passwordless sudo"
    )
    lacking_sudo = (
        host for host, entry in zip(hosts, sshinfo) if not entry.has_sudo
    )
    for host in lacking_sudo:
        log_warning(" - ", host.ssh_hostname)

    return everyone_has_sudo
|
||||
|
||||
|
||||
class IPConfigurator:
|
||||
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
|
||||
assigned = set()
|
||||
@@ -278,6 +317,8 @@ def check_valid_mesh(hosts, connectivity, strict=True):
|
||||
log_error(
|
||||
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
|
||||
)
|
||||
log_error()
|
||||
log_error("Try passing --dot to visualize the connectivity")
|
||||
sys.exit(1)
|
||||
else:
|
||||
return False
|
||||
@@ -365,7 +406,7 @@ def prepare_ethernet_hostfile(args, hosts):
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def configure_ring(args, hosts, ips, ring):
|
||||
def configure_ring(args, hosts, ips, ring, sshinfo):
|
||||
log(args.verbose, "Prepare a ring hostfile")
|
||||
ring, count = ring
|
||||
hostfile = []
|
||||
@@ -380,7 +421,8 @@ def configure_ring(args, hosts, ips, ring):
|
||||
}
|
||||
)
|
||||
|
||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
|
||||
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
|
||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
|
||||
|
||||
if args.output_hostfile:
|
||||
with open(args.output_hostfile, "w") as f:
|
||||
@@ -391,9 +433,10 @@ def configure_ring(args, hosts, ips, ring):
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def configure_jaccl(args, hosts, ips):
|
||||
def configure_jaccl(args, hosts, ips, sshinfo):
|
||||
log(args.verbose, "Prepare a jaccl hostfile")
|
||||
add_ethernet_ips(hosts)
|
||||
check_rdma(hosts, args.verbose)
|
||||
add_ethernet_ips(hosts, args.verbose)
|
||||
|
||||
hostfile = []
|
||||
for i, h in enumerate(hosts):
|
||||
@@ -405,7 +448,8 @@ def configure_jaccl(args, hosts, ips):
|
||||
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
|
||||
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
|
||||
|
||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
|
||||
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
|
||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
|
||||
|
||||
if args.output_hostfile:
|
||||
with open(args.output_hostfile, "w") as f:
|
||||
@@ -416,7 +460,7 @@ def configure_jaccl(args, hosts, ips):
|
||||
print(json.dumps(hostfile, indent=4))
|
||||
|
||||
|
||||
def prepare_tb_hostfile(args, hosts):
|
||||
def prepare_tb_hostfile(args, hosts, sshinfo):
|
||||
log(args.verbose, f"Preparing for communication over thunderbolt")
|
||||
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
|
||||
|
||||
@@ -438,26 +482,28 @@ def prepare_tb_hostfile(args, hosts):
|
||||
sys.exit(1)
|
||||
|
||||
elif has_ring:
|
||||
configure_ring(args, hosts, ips, rings[0])
|
||||
configure_ring(args, hosts, ips, rings[0], sshinfo)
|
||||
|
||||
else:
|
||||
configure_jaccl(args, hosts, ips)
|
||||
configure_jaccl(args, hosts, ips, sshinfo)
|
||||
|
||||
elif args.backend == "ring":
|
||||
rings = extract_rings(connectivity)
|
||||
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
|
||||
if not has_ring:
|
||||
log_error("Could not find a full ring.")
|
||||
log_error()
|
||||
log_error("Try passing --dot to visualize the connectivity")
|
||||
if len(rings) > 0:
|
||||
log_error("Rings found:")
|
||||
for r in rings:
|
||||
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
|
||||
sys.exit(1)
|
||||
configure_ring(args, hosts, ips, rings[0])
|
||||
configure_ring(args, hosts, ips, rings[0], sshinfo)
|
||||
|
||||
elif args.backend == "jaccl":
|
||||
check_valid_mesh(hosts, connectivity)
|
||||
configure_jaccl(args, hosts, ips)
|
||||
configure_jaccl(args, hosts, ips, sshinfo)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -498,10 +544,6 @@ def main():
|
||||
default=None,
|
||||
help="Which distributed backend to configure",
|
||||
)
|
||||
|
||||
# parser.add_argument(
|
||||
# "--hostfile-only", action="store_true", help="If set only compute the hostfile"
|
||||
# )
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.hostfile is not None:
|
||||
@@ -514,7 +556,7 @@ def main():
|
||||
args.verbose,
|
||||
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
|
||||
)
|
||||
check_ssh_connections(hosts)
|
||||
sshinfo = check_ssh_connections(hosts)
|
||||
|
||||
# Prepare a hostfile for communication over ethernet using the ips of the
|
||||
# provided hostnames
|
||||
@@ -523,4 +565,4 @@ def main():
|
||||
|
||||
# Configure the macs for communication over thunderbolt, both via RDMA and IP
|
||||
else:
|
||||
prepare_tb_hostfile(args, hosts)
|
||||
prepare_tb_hostfile(args, hosts, sshinfo)
|
||||
|
||||
Reference in New Issue
Block a user