mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
NCCL backend (#2476)
This commit is contained in:
committed by
GitHub
parent
e843c4d8d5
commit
9392fc3f88
@@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command):
|
||||
pass
|
||||
|
||||
|
||||
def launch_nccl(parser, hosts, args, command):
|
||||
master_host = hosts[0].ips[0]
|
||||
|
||||
if master_host != "127.0.0.1":
|
||||
raise ValueError("The NCCL backend only supports localhost for now. ")
|
||||
master_port = args.nccl_port
|
||||
world_size = len(hosts)
|
||||
|
||||
base_env = os.environ.copy()
|
||||
base_env.update(
|
||||
{
|
||||
"NCCL_DEBUG": "INFO",
|
||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
||||
"NCCL_HOST_IP": master_host,
|
||||
"NCCL_PORT": str(master_port),
|
||||
"MLX_WORLD_SIZE": str(world_size),
|
||||
}
|
||||
)
|
||||
procs = []
|
||||
try:
|
||||
for rank in range(world_size):
|
||||
env = base_env.copy()
|
||||
env["MLX_RANK"] = str(rank)
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
||||
p = Popen(command, env=env)
|
||||
procs.append(p)
|
||||
|
||||
for p in procs:
|
||||
ret = p.wait()
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Rank process exited with {ret}")
|
||||
|
||||
except (RuntimeError, KeyboardInterrupt) as err:
|
||||
for p in procs:
|
||||
if p.poll() is None:
|
||||
try:
|
||||
p.kill()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def check_ssh_connections(hosts):
|
||||
results = [False] * len(hosts)
|
||||
|
||||
@@ -665,7 +707,7 @@ def distributed_config():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi"],
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
default="ring",
|
||||
help="Which distributed backend to configure",
|
||||
)
|
||||
@@ -737,7 +779,7 @@ def main():
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi"],
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
default="ring",
|
||||
help="Which distributed backend to launch",
|
||||
)
|
||||
@@ -769,6 +811,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--cwd", help="Set the working directory on each node to the provided one"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nccl-port",
|
||||
type=int,
|
||||
default=12345,
|
||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||
)
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
if rest[0] == "--":
|
||||
rest.pop(0)
|
||||
@@ -799,8 +848,10 @@ def main():
|
||||
# Launch
|
||||
if args.backend == "ring":
|
||||
launch_ring(parser, hosts, args, rest)
|
||||
elif args.backend == "mpi":
|
||||
if args.backend == "mpi":
|
||||
launch_mpi(parser, hosts, args, rest)
|
||||
if args.backend == "nccl":
|
||||
launch_nccl(parser, hosts, args, rest)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user