nccl backend

This commit is contained in:
Anastasiia Filippova
2025-08-07 13:11:56 +02:00
parent 56be773610
commit f540b1d612
18 changed files with 869 additions and 12 deletions

View File

@@ -415,6 +415,45 @@ def launch_mpi(parser, hosts, args, command):
pass
def launch_nccl(parser, hosts, args, command):
    """Spawn one local process per rank for the NCCL backend and supervise them.

    Every rank runs ``command`` with a copy of the current environment
    extended with the NCCL rendezvous settings (``NCCL_HOST_IP``,
    ``NCCL_PORT``), the world size (``MLX_WORLD_SIZE``) and its own rank
    (``MLX_RANK``). All ranks are started on the local machine.

    Args:
        parser: The argument parser; accepted for signature parity with the
            other launchers and not used here.
        hosts: Sequence of host objects; ``hosts[0].ips[0]`` is used as the
            rendezvous address.
        args: Parsed arguments providing ``nccl_port``, ``nproc_per_node``
            and ``nnodes``.
        command: The argv list executed for every rank.

    Raises:
        RuntimeError: If any rank exits with a non-zero status. All ranks
            that are still running are killed before the error propagates.
    """
    import time

    master_host = hosts[0].ips[0]
    master_port = args.nccl_port
    world_size = args.nproc_per_node * args.nnodes

    base_env = os.environ.copy()
    base_env.update(
        {
            "NCCL_DEBUG": "INFO",
            "NCCL_SOCKET_IFNAME": "lo",  # Use loopback for local communication
            "NCCL_HOST_IP": master_host,
            "NCCL_PORT": str(master_port),
            "MLX_WORLD_SIZE": str(world_size),
        }
    )

    procs = []
    try:
        for rank in range(world_size):
            env = base_env.copy()
            env["MLX_RANK"] = str(rank)
            procs.append(Popen(command, env=env))

        # Poll every rank instead of blocking on them one at a time: a rank
        # that is stuck waiting for a crashed peer would otherwise keep the
        # launcher from ever noticing the failure.
        pending = list(procs)
        while pending:
            still_running = []
            for proc in pending:
                ret = proc.poll()
                if ret is None:
                    still_running.append(proc)
                elif ret != 0:
                    raise RuntimeError(f"Rank process exited with {ret}")
            pending = still_running
            if pending:
                time.sleep(0.1)
    finally:
        # On the success path every rank has already exited, so this is a
        # no-op. On any failure (non-zero exit, Ctrl-C, a later Popen raising)
        # tear down whatever is still alive so no rank is orphaned.
        for proc in procs:
            if proc.poll() is None:
                try:
                    proc.kill()
                except OSError:
                    # The process went away between poll() and kill().
                    pass
                # Reap the child so it does not linger as a zombie.
                proc.wait()
def check_ssh_connections(hosts):
results = [False] * len(hosts)
@@ -665,7 +704,7 @@ def distributed_config():
)
parser.add_argument(
"--backend",
choices=["ring", "mpi"],
choices=["ring", "mpi", "nccl"],
default="ring",
help="Which distributed backend to configure",
)
@@ -737,7 +776,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi"],
choices=["ring", "mpi", "nccl"],
default="ring",
help="Which distributed backend to launch",
)
@@ -769,6 +808,26 @@ def main():
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
parser.add_argument(
"--nproc-per-node",
default=8,
type=int,
help="The number of processes to launch on each node (only for nccl backend)",
)
# TODO: Add support for multiple nodes
parser.add_argument(
"--nnodes",
default=1,
type=int,
help="The number of nodes to launch (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if rest[0] == "--":
rest.pop(0)
@@ -799,8 +858,10 @@ def main():
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
elif args.backend == "mpi":
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if __name__ == "__main__":