NCCL backend (#2476)

Author: Anastasiia Filippova
Date: 2025-08-21 20:56:15 +02:00
Committed by: GitHub
Parent: e843c4d8d5
Commit: 9392fc3f88

21 changed files with 897 additions and 20 deletions


@@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command):
            pass


def launch_nccl(parser, hosts, args, command):
    master_host = hosts[0].ips[0]
    if master_host != "127.0.0.1":
        raise ValueError("The NCCL backend only supports localhost for now. ")

    master_port = args.nccl_port
    world_size = len(hosts)

    base_env = os.environ.copy()
    base_env.update(
        {
            "NCCL_DEBUG": "INFO",
            "NCCL_SOCKET_IFNAME": "lo",  # Use loopback for local communication
            "NCCL_HOST_IP": master_host,
            "NCCL_PORT": str(master_port),
            "MLX_WORLD_SIZE": str(world_size),
        }
    )

    procs = []
    try:
        for rank in range(world_size):
            env = base_env.copy()
            env["MLX_RANK"] = str(rank)
            env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
            p = Popen(command, env=env)
            procs.append(p)

        for p in procs:
            ret = p.wait()
            if ret != 0:
                raise RuntimeError(f"Rank process exited with {ret}")
    except (RuntimeError, KeyboardInterrupt) as err:
        for p in procs:
            if p.poll() is None:
                try:
                    p.kill()
                except Exception:
                    pass
        raise


def check_ssh_connections(hosts):
    results = [False] * len(hosts)
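Each rank started by launch_nccl sees its coordinates only through the environment variables set above. A minimal sketch of a worker script follows; it assumes that mx.distributed.init accepts the new "nccl" backend name after this change and that rank and world size are picked up from MLX_RANK / MLX_WORLD_SIZE (the tensor shape is a placeholder):

    import mlx.core as mx

    # Assumption: the "nccl" backend name is accepted by mx.distributed.init
    # in this commit; rank and world size come from the launcher's environment.
    group = mx.distributed.init(backend="nccl")
    print(f"rank {group.rank()} of {group.size()}")

    # NCCL collectives run on the GPU, so route the reduction to the GPU stream.
    x = mx.ones(8)
    mx.eval(mx.distributed.all_sum(x, group=group, stream=mx.gpu))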
@@ -665,7 +707,7 @@ def distributed_config():
    )
    parser.add_argument(
        "--backend",
        choices=["ring", "mpi"],
        choices=["ring", "mpi", "nccl"],
        default="ring",
        help="Which distributed backend to configure",
    )
@@ -737,7 +779,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi"],
choices=["ring", "mpi", "nccl"],
default="ring",
help="Which distributed backend to launch",
)
@@ -769,6 +811,13 @@ def main():
    parser.add_argument(
        "--cwd", help="Set the working directory on each node to the provided one"
    )
    parser.add_argument(
        "--nccl-port",
        type=int,
        default=12345,
        help="The port to use for the NCCL communication (only for nccl backend)",
    )
    args, rest = parser.parse_known_args()
    if rest[0] == "--":
        rest.pop(0)
@@ -799,8 +848,10 @@ def main():
    # Launch
    if args.backend == "ring":
        launch_ring(parser, hosts, args, rest)
    elif args.backend == "mpi":
    if args.backend == "mpi":
        launch_mpi(parser, hosts, args, rest)
    if args.backend == "nccl":
        launch_nccl(parser, hosts, args, rest)


if __name__ == "__main__":
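With this dispatch in place, the nccl backend is selected through the same launcher entry point as ring and mpi, for example: mlx.launch --backend nccl --nccl-port 12345 train.py. Only --backend and --nccl-port come from this diff; train.py is a placeholder, and host/rank selection follows the launcher's existing options.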


@@ -76,6 +76,7 @@ def average_gradients(
    group: Optional[mx.distributed.Group] = None,
    all_reduce_size: int = 32 * 1024**2,
    communication_type: Optional[mx.Dtype] = None,
    stream: mx.Stream = mx.cpu,
):
    """Average the gradients across the distributed processes in the passed group.
@@ -94,6 +95,7 @@ def average_gradients(
        communication_type (Optional[mlx.core.Dtype]): If provided cast to this
            type before performing the communication. Typically cast to a
            smaller float to reduce the communication size. Default: ``None``.
        stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
    """
    group = group or mx.distributed.init()
    N = group.size()
@@ -104,7 +106,7 @@ def average_gradients(
    def _average(x):
        dt = x.dtype
        x = x.astype(communication_type) if communication_type is not None else x
        return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
        return mx.distributed.all_sum(x, stream=stream).astype(dt) / N

    if all_reduce_size <= 0:
        return tree_map(_average, gradients)
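The new stream argument lets a caller keep the gradient all-reduce on the GPU when the process group is backed by NCCL, while the mx.cpu default preserves the previous behaviour. A minimal sketch, assuming mlx.nn.average_gradients is the function patched above and that the process was launched with the nccl backend:

    import mlx.core as mx
    import mlx.nn as nn

    group = mx.distributed.init(backend="nccl")
    grads = {"w": mx.ones((4, 4)), "b": mx.zeros((4,))}
    # stream=mx.gpu runs the all_sum where NCCL operates; omitting it keeps
    # the old CPU-stream path.
    averaged = nn.average_gradients(grads, group=group, stream=mx.gpu)
    mx.eval(averaged)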