mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add more checks and improve errors
This commit is contained in:
@@ -35,6 +35,7 @@ def positive_number(x):
|
|||||||
def log(verbose, *args, **kwargs):
|
def log(verbose, *args, **kwargs):
|
||||||
if not verbose:
|
if not verbose:
|
||||||
return
|
return
|
||||||
|
kwargs["file"] = sys.stderr
|
||||||
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
print("\033[32m[INFO]", *args, "\033[0m", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -57,6 +57,45 @@ def add_ethernet_ips(hosts, verbose=False):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def check_rdma(hosts, verbose=False):
|
||||||
|
# Check whether the hosts are capable of RDMA over thunderbolt
|
||||||
|
warn = False
|
||||||
|
for h in hosts:
|
||||||
|
log(verbose, "Checking that", h.ssh_hostname, "supports RDMA")
|
||||||
|
rdma_devs = (
|
||||||
|
run(["ssh", h.ssh_hostname, "ibv_devices"], capture_output=True, text=True)
|
||||||
|
.stdout.strip()
|
||||||
|
.split()
|
||||||
|
)
|
||||||
|
rdma_devs = [d for d in rdma_devs if d.startswith("rdma_")]
|
||||||
|
if not rdma_devs:
|
||||||
|
log_warning(h.ssh_hostname, "does not seem to have RDMA enabled")
|
||||||
|
warn = True
|
||||||
|
|
||||||
|
if warn:
|
||||||
|
log_warning()
|
||||||
|
log_warning(
|
||||||
|
"Some of the hosts don't have RDMA enabled or they don't support RDMA."
|
||||||
|
)
|
||||||
|
log_warning()
|
||||||
|
log_warning(
|
||||||
|
"See https://ml-explore.github.io/mlx/build/html/usage/distributed.html"
|
||||||
|
)
|
||||||
|
log_warning("for instructions on how to enable RDMA.")
|
||||||
|
|
||||||
|
|
||||||
|
def can_auto_setup(hosts, sshinfo, auto_setup=False):
|
||||||
|
has_sudo = all(info.has_sudo for info in sshinfo)
|
||||||
|
if not has_sudo and auto_setup:
|
||||||
|
log_warning(
|
||||||
|
"Automatic setup requested but the following hosts do not have passwordless sudo"
|
||||||
|
)
|
||||||
|
for h, i in zip(hosts, sshinfo):
|
||||||
|
if not i.has_sudo:
|
||||||
|
log_warning(" - ", h.ssh_hostname)
|
||||||
|
return has_sudo
|
||||||
|
|
||||||
|
|
||||||
class IPConfigurator:
|
class IPConfigurator:
|
||||||
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
|
def __init__(self, hosts, tb_hosts, uuid_reverse_index):
|
||||||
assigned = set()
|
assigned = set()
|
||||||
@@ -278,6 +317,8 @@ def check_valid_mesh(hosts, connectivity, strict=True):
|
|||||||
log_error(
|
log_error(
|
||||||
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
|
f"Incomplete mesh, {hosts[i].ssh_hostname} is not connected to {hosts[j].ssh_hostname}"
|
||||||
)
|
)
|
||||||
|
log_error()
|
||||||
|
log_error("Try passing --dot to visualize the connectivity")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
@@ -365,7 +406,7 @@ def prepare_ethernet_hostfile(args, hosts):
|
|||||||
print(json.dumps(hostfile, indent=4))
|
print(json.dumps(hostfile, indent=4))
|
||||||
|
|
||||||
|
|
||||||
def configure_ring(args, hosts, ips, ring):
|
def configure_ring(args, hosts, ips, ring, sshinfo):
|
||||||
log(args.verbose, "Prepare a ring hostfile")
|
log(args.verbose, "Prepare a ring hostfile")
|
||||||
ring, count = ring
|
ring, count = ring
|
||||||
hostfile = []
|
hostfile = []
|
||||||
@@ -380,7 +421,8 @@ def configure_ring(args, hosts, ips, ring):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
|
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
|
||||||
|
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
|
||||||
|
|
||||||
if args.output_hostfile:
|
if args.output_hostfile:
|
||||||
with open(args.output_hostfile, "w") as f:
|
with open(args.output_hostfile, "w") as f:
|
||||||
@@ -391,9 +433,10 @@ def configure_ring(args, hosts, ips, ring):
|
|||||||
print(json.dumps(hostfile, indent=4))
|
print(json.dumps(hostfile, indent=4))
|
||||||
|
|
||||||
|
|
||||||
def configure_jaccl(args, hosts, ips):
|
def configure_jaccl(args, hosts, ips, sshinfo):
|
||||||
log(args.verbose, "Prepare a jaccl hostfile")
|
log(args.verbose, "Prepare a jaccl hostfile")
|
||||||
add_ethernet_ips(hosts)
|
check_rdma(hosts, args.verbose)
|
||||||
|
add_ethernet_ips(hosts, args.verbose)
|
||||||
|
|
||||||
hostfile = []
|
hostfile = []
|
||||||
for i, h in enumerate(hosts):
|
for i, h in enumerate(hosts):
|
||||||
@@ -405,7 +448,8 @@ def configure_jaccl(args, hosts, ips):
|
|||||||
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
|
rdma.append(f"rdma_{ips.ips[i, j][0][0]}")
|
||||||
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
|
hostfile.append({"ssh": h.ssh_hostname, "ips": h.ips, "rdma": rdma})
|
||||||
|
|
||||||
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup)
|
has_sudo = can_auto_setup(hosts, sshinfo, args.auto_setup)
|
||||||
|
ips.setup(verbose=args.verbose, auto_setup=args.auto_setup and has_sudo)
|
||||||
|
|
||||||
if args.output_hostfile:
|
if args.output_hostfile:
|
||||||
with open(args.output_hostfile, "w") as f:
|
with open(args.output_hostfile, "w") as f:
|
||||||
@@ -416,7 +460,7 @@ def configure_jaccl(args, hosts, ips):
|
|||||||
print(json.dumps(hostfile, indent=4))
|
print(json.dumps(hostfile, indent=4))
|
||||||
|
|
||||||
|
|
||||||
def prepare_tb_hostfile(args, hosts):
|
def prepare_tb_hostfile(args, hosts, sshinfo):
|
||||||
log(args.verbose, f"Preparing for communication over thunderbolt")
|
log(args.verbose, f"Preparing for communication over thunderbolt")
|
||||||
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
|
tb_hosts, uuid_reverse_index = extract_connectivity(hosts, args.verbose)
|
||||||
|
|
||||||
@@ -438,26 +482,28 @@ def prepare_tb_hostfile(args, hosts):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
elif has_ring:
|
elif has_ring:
|
||||||
configure_ring(args, hosts, ips, rings[0])
|
configure_ring(args, hosts, ips, rings[0], sshinfo)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
configure_jaccl(args, hosts, ips)
|
configure_jaccl(args, hosts, ips, sshinfo)
|
||||||
|
|
||||||
elif args.backend == "ring":
|
elif args.backend == "ring":
|
||||||
rings = extract_rings(connectivity)
|
rings = extract_rings(connectivity)
|
||||||
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
|
has_ring = len(rings) > 0 and len(rings[0][0]) == len(hosts)
|
||||||
if not has_ring:
|
if not has_ring:
|
||||||
log_error("Could not find a full ring.")
|
log_error("Could not find a full ring.")
|
||||||
|
log_error()
|
||||||
|
log_error("Try passing --dot to visualize the connectivity")
|
||||||
if len(rings) > 0:
|
if len(rings) > 0:
|
||||||
log_error("Rings found:")
|
log_error("Rings found:")
|
||||||
for r in rings:
|
for r in rings:
|
||||||
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
|
log_error(f" - {','.join(hosts[i].ssh_hostname for i in r)}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
configure_ring(args, hosts, ips, rings[0])
|
configure_ring(args, hosts, ips, rings[0], sshinfo)
|
||||||
|
|
||||||
elif args.backend == "jaccl":
|
elif args.backend == "jaccl":
|
||||||
check_valid_mesh(hosts, connectivity)
|
check_valid_mesh(hosts, connectivity)
|
||||||
configure_jaccl(args, hosts, ips)
|
configure_jaccl(args, hosts, ips, sshinfo)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -498,10 +544,6 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Which distributed backend to configure",
|
help="Which distributed backend to configure",
|
||||||
)
|
)
|
||||||
|
|
||||||
# parser.add_argument(
|
|
||||||
# "--hostfile-only", action="store_true", help="If set only compute the hostfile"
|
|
||||||
# )
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.hostfile is not None:
|
if args.hostfile is not None:
|
||||||
@@ -514,7 +556,7 @@ def main():
|
|||||||
args.verbose,
|
args.verbose,
|
||||||
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
|
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
|
||||||
)
|
)
|
||||||
check_ssh_connections(hosts)
|
sshinfo = check_ssh_connections(hosts)
|
||||||
|
|
||||||
# Prepare a hostfile for communication over ethernet using the ips of the
|
# Prepare a hostfile for communication over ethernet using the ips of the
|
||||||
# provided hostnames
|
# provided hostnames
|
||||||
@@ -523,4 +565,4 @@ def main():
|
|||||||
|
|
||||||
# Configure the macs for communication over thunderbolt, both via RDMA and IP
|
# Configure the macs for communication over thunderbolt, both via RDMA and IP
|
||||||
else:
|
else:
|
||||||
prepare_tb_hostfile(args, hosts)
|
prepare_tb_hostfile(args, hosts, sshinfo)
|
||||||
|
|||||||
Reference in New Issue
Block a user