diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 424873dd3..448e3f954 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -636,9 +636,17 @@ def prepare_tb_ring(args, hosts): if ip0 > 255: raise ValueError("Ran out of available local IPs for the ring") + # Extract the host order from the first ring + hostmap = dict((r[0][0], r[1][0]) for r in rings[0]) + first_host = min(hostmap.keys()) + order = [first_host] + while hostmap[order[-1]] != first_host: + order.append(hostmap[order[-1]]) + # Create the hostfile hostfile = [] - for i, h in enumerate(hosts): + for i in order: + h = hosts[i] host = { "ssh": h.ssh_hostname, "ips": [