diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index bb0e3c633..424873dd3 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -567,13 +567,16 @@ def prepare_tb_ring(args, hosts): name = "" ports = [] for t in c["SPThunderboltDataType"]: + uuid = t.get("domain_uuid_key") + if uuid is None: + continue name = t["device_name_key"] - uuid = t["domain_uuid_key"] tag = t["receptacle_1_tag"]["receptacle_id_key"] - if items := t.get("_items", []): - connected_to = items[0]["domain_uuid_key"] - else: - connected_to = None + items = t.get("_items", []) + connected_items = [item for item in items if "domain_uuid_key" in item] + connected_to = ( + connected_items[0]["domain_uuid_key"] if connected_items else None + ) iface = iface_map[f"Thunderbolt {tag}"] ports.append(ThunderboltPort(iface, uuid, connected_to)) tb_hosts.append(ThunderboltHost(name, sorted(ports, key=lambda x: x.iface)))