mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-05 11:28:12 +08:00
NCCL backend (#2476)
This commit is contained in:
committed by
GitHub
parent
e843c4d8d5
commit
9392fc3f88
@@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command):
|
||||
pass
|
||||
|
||||
|
||||
def launch_nccl(parser, hosts, args, command):
|
||||
master_host = hosts[0].ips[0]
|
||||
|
||||
if master_host != "127.0.0.1":
|
||||
raise ValueError("The NCCL backend only supports localhost for now. ")
|
||||
master_port = args.nccl_port
|
||||
world_size = len(hosts)
|
||||
|
||||
base_env = os.environ.copy()
|
||||
base_env.update(
|
||||
{
|
||||
"NCCL_DEBUG": "INFO",
|
||||
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
|
||||
"NCCL_HOST_IP": master_host,
|
||||
"NCCL_PORT": str(master_port),
|
||||
"MLX_WORLD_SIZE": str(world_size),
|
||||
}
|
||||
)
|
||||
procs = []
|
||||
try:
|
||||
for rank in range(world_size):
|
||||
env = base_env.copy()
|
||||
env["MLX_RANK"] = str(rank)
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
|
||||
p = Popen(command, env=env)
|
||||
procs.append(p)
|
||||
|
||||
for p in procs:
|
||||
ret = p.wait()
|
||||
if ret != 0:
|
||||
raise RuntimeError(f"Rank process exited with {ret}")
|
||||
|
||||
except (RuntimeError, KeyboardInterrupt) as err:
|
||||
for p in procs:
|
||||
if p.poll() is None:
|
||||
try:
|
||||
p.kill()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
|
||||
def check_ssh_connections(hosts):
|
||||
results = [False] * len(hosts)
|
||||
|
||||
@@ -665,7 +707,7 @@ def distributed_config():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi"],
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
default="ring",
|
||||
help="Which distributed backend to configure",
|
||||
)
|
||||
@@ -737,7 +779,7 @@ def main():
|
||||
parser.add_argument("--hostfile", help="The file containing the hosts")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
choices=["ring", "mpi"],
|
||||
choices=["ring", "mpi", "nccl"],
|
||||
default="ring",
|
||||
help="Which distributed backend to launch",
|
||||
)
|
||||
@@ -769,6 +811,13 @@ def main():
|
||||
parser.add_argument(
|
||||
"--cwd", help="Set the working directory on each node to the provided one"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--nccl-port",
|
||||
type=int,
|
||||
default=12345,
|
||||
help="The port to use for the NCCL communication (only for nccl backend)",
|
||||
)
|
||||
|
||||
args, rest = parser.parse_known_args()
|
||||
if rest[0] == "--":
|
||||
rest.pop(0)
|
||||
@@ -799,8 +848,10 @@ def main():
|
||||
# Launch
|
||||
if args.backend == "ring":
|
||||
launch_ring(parser, hosts, args, rest)
|
||||
elif args.backend == "mpi":
|
||||
if args.backend == "mpi":
|
||||
launch_mpi(parser, hosts, args, rest)
|
||||
if args.backend == "nccl":
|
||||
launch_nccl(parser, hosts, args, rest)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -76,6 +76,7 @@ def average_gradients(
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
all_reduce_size: int = 32 * 1024**2,
|
||||
communication_type: Optional[mx.Dtype] = None,
|
||||
stream: mx.Stream = mx.cpu,
|
||||
):
|
||||
"""Average the gradients across the distributed processes in the passed group.
|
||||
|
||||
@@ -94,6 +95,7 @@ def average_gradients(
|
||||
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
|
||||
type before performing the communication. Typically cast to a
|
||||
smaller float to reduce the communication size. Default: ``None``.
|
||||
stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
|
||||
"""
|
||||
group = group or mx.distributed.init()
|
||||
N = group.size()
|
||||
@@ -104,7 +106,7 @@ def average_gradients(
|
||||
def _average(x):
|
||||
dt = x.dtype
|
||||
x = x.astype(communication_type) if communication_type is not None else x
|
||||
return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
|
||||
return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
|
||||
|
||||
if all_reduce_size <= 0:
|
||||
return tree_map(_average, gradients)
|
||||
|
||||
Reference in New Issue
Block a user