NCCL backend (#2476)

This commit is contained in:
Anastasiia Filippova
2025-08-21 20:56:15 +02:00
committed by GitHub
parent e843c4d8d5
commit 9392fc3f88
21 changed files with 897 additions and 20 deletions

View File

@@ -415,6 +415,48 @@ def launch_mpi(parser, hosts, args, command):
pass
def launch_nccl(parser, hosts, args, command):
master_host = hosts[0].ips[0]
if master_host != "127.0.0.1":
raise ValueError("The NCCL backend only supports localhost for now. ")
master_port = args.nccl_port
world_size = len(hosts)
base_env = os.environ.copy()
base_env.update(
{
"NCCL_DEBUG": "INFO",
"NCCL_SOCKET_IFNAME": "lo", # Use loopback for local communication
"NCCL_HOST_IP": master_host,
"NCCL_PORT": str(master_port),
"MLX_WORLD_SIZE": str(world_size),
}
)
procs = []
try:
for rank in range(world_size):
env = base_env.copy()
env["MLX_RANK"] = str(rank)
env["CUDA_VISIBLE_DEVICES"] = str(rank % args.nproc_per_node)
p = Popen(command, env=env)
procs.append(p)
for p in procs:
ret = p.wait()
if ret != 0:
raise RuntimeError(f"Rank process exited with {ret}")
except (RuntimeError, KeyboardInterrupt) as err:
for p in procs:
if p.poll() is None:
try:
p.kill()
except Exception:
pass
raise
def check_ssh_connections(hosts):
results = [False] * len(hosts)
@@ -665,7 +707,7 @@ def distributed_config():
)
parser.add_argument(
"--backend",
choices=["ring", "mpi"],
choices=["ring", "mpi", "nccl"],
default="ring",
help="Which distributed backend to configure",
)
@@ -737,7 +779,7 @@ def main():
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--backend",
choices=["ring", "mpi"],
choices=["ring", "mpi", "nccl"],
default="ring",
help="Which distributed backend to launch",
)
@@ -769,6 +811,13 @@ def main():
parser.add_argument(
"--cwd", help="Set the working directory on each node to the provided one"
)
parser.add_argument(
"--nccl-port",
type=int,
default=12345,
help="The port to use for the NCCL communication (only for nccl backend)",
)
args, rest = parser.parse_known_args()
if rest[0] == "--":
rest.pop(0)
@@ -799,8 +848,10 @@ def main():
# Launch
if args.backend == "ring":
launch_ring(parser, hosts, args, rest)
elif args.backend == "mpi":
if args.backend == "mpi":
launch_mpi(parser, hosts, args, rest)
if args.backend == "nccl":
launch_nccl(parser, hosts, args, rest)
if __name__ == "__main__":

View File

@@ -76,6 +76,7 @@ def average_gradients(
group: Optional[mx.distributed.Group] = None,
all_reduce_size: int = 32 * 1024**2,
communication_type: Optional[mx.Dtype] = None,
stream: mx.Stream = mx.cpu,
):
"""Average the gradients across the distributed processes in the passed group.
@@ -94,6 +95,7 @@ def average_gradients(
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
type before performing the communication. Typically cast to a
smaller float to reduce the communication size. Default: ``None``.
stream (mlx.core.Stream): The stream to use for the reduction. Default: ``mlx.cpu``.
"""
group = group or mx.distributed.init()
N = group.size()
@@ -104,7 +106,7 @@ def average_gradients(
def _average(x):
dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x
return mx.distributed.all_sum(x, stream=mx.cpu).astype(dt) / N
return mx.distributed.all_sum(x, stream=stream).astype(dt) / N
if all_reduce_size <= 0:
return tree_map(_average, gradients)

View File

@@ -79,7 +79,7 @@ void init_distributed(nb::module_& parent_module) {
in case ``mx.distributed.is_available()`` returns False otherwise
it throws a runtime error. Default: ``False``
backend (str, optional): Which distributed backend to initialize.
Possible values ``mpi``, ``ring``, ``any``. If set to ``any`` all
Possible values ``mpi``, ``ring``, ``nccl``, ``any``. If set to ``any`` all
available backends are tried and the first one that succeeds
becomes the global group which will be returned in subsequent
calls. Default: ``any``

View File

@@ -0,0 +1,284 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import mlx.nn as nn
import mlx_tests
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from mlx.nn.utils import average_gradients
class TestNCCLDistributed(mlx_tests.MLXTestCase):
@classmethod
def setUpClass(cls):
world = mx.distributed.init(strict=True, backend="nccl")
rank = world.rank()
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
def test_all_reduce(self):
world = mx.distributed.init()
dtypes = [
(mx.int8, 0),
(mx.uint8, 0),
(mx.int32, 0),
(mx.uint32, 0),
(mx.float32, 1e-6),
(mx.float16, 5e-3),
(mx.bfloat16, 1e-1),
]
sizes = [
(7,),
(10,),
(1024,),
(1024, 1024),
]
key = mx.random.key(0)
for dt, rtol in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
# All sum
y = mx.distributed.all_sum(x[world.rank()])
z = x.sum(0)
maxrelerror = (y - z).abs()
if rtol > 0:
maxrelerror /= z.abs()
maxrelerror = maxrelerror.max()
self.assertLessEqual(maxrelerror, rtol)
def test_average_gradients(self):
original_all_sum = mx.distributed.all_sum
n_calls = 0
xtype = None
def new_all_sum(x, **kwargs):
nonlocal n_calls
nonlocal xtype
n_calls += 1
if xtype is not None:
self.assertEqual(xtype, x.dtype)
return original_all_sum(x, **kwargs)
mx.distributed.all_sum = new_all_sum
try:
grads = [mx.ones(10) for i in range(10)]
new_grads = average_gradients(grads, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 1)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
n_calls = 0
new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 10)
n_calls = 0
xtype = mx.float16
new_grads = average_gradients(
grads,
all_reduce_size=2 * 50,
communication_type=mx.float16,
stream=mx.gpu,
)
mx.eval(new_grads)
self.assertEqual(len(new_grads), 10)
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
self.assertEqual(n_calls, 2)
finally:
mx.distributed.all_sum = original_all_sum
def test_donation(self):
x = mx.random.normal((1024,))
mx.eval(x)
mx.synchronize()
mx.reset_peak_memory()
scale = mx.array(2.0)
y = mx.distributed.all_sum(x)
mx.eval(y)
mx.synchronize()
all_sum_only = mx.get_peak_memory()
y = mx.distributed.all_sum(x) * scale
mx.eval(y)
mx.synchronize()
all_sum_with_binary = mx.get_peak_memory()
self.assertEqual(all_sum_only, all_sum_with_binary)
def test_shard_linear(self):
# Seed the prng to have the same inputs and weights generated everywhere
mx.random.seed(0xF0F0F0F0)
# Prepare inputs
world = mx.distributed.init()
part = (
slice(None),
slice(
world.rank() * 1024 // world.size(),
(world.rank() + 1) * 1024 // world.size(),
),
)
x = mx.random.normal((4, 1024))
# Create and shard some linear layers
lin = nn.Linear(1024, 1024, bias=True)
slin1 = shard_linear(lin, "all-to-sharded")
slin2 = shard_linear(lin, "sharded-to-all")
y = lin(x)
y1 = slin1(x)
y2 = slin2(x[part])
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
# Check the backward works as expected
def dummy_loss(model, x, y):
return (model(x) * y).sum()
mod = nn.Sequential(
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
nn.Linear(128, 128),
)
smod = nn.Sequential(
shard_linear(mod.layers[0], "all-to-sharded"),
shard_linear(mod.layers[1], "sharded-to-all"),
shard_linear(mod.layers[2], "all-to-sharded"),
shard_linear(mod.layers[3], "sharded-to-all"),
)
grad1 = nn.value_and_grad(mod, dummy_loss)
grad2 = nn.value_and_grad(smod, dummy_loss)
x = mx.random.normal((4, 128))
y = mx.random.normal((4, 128))
l1, g1 = grad1(mod, x, y)
l2, g2 = grad2(smod, x, y)
mx.eval(l1, g1, l2, g2)
part = slice(
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
)
self.assertTrue(mx.allclose(l1, l2))
self.assertTrue(
mx.allclose(
g1["layers"][0]["weight"][part],
g2["layers"][0]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["weight"][part],
g2["layers"][2]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["weight"][:, part],
g2["layers"][1]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["weight"][:, part],
g2["layers"][3]["weight"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][0]["bias"][part],
g2["layers"][0]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][2]["bias"][part],
g2["layers"][2]["bias"],
atol=1e-6,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
)
)
self.assertTrue(
mx.allclose(
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
)
)
def test_shard_predicate(self):
mx.random.seed(0xF0F0F0F0)
class MyConv(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.aggregate = kwargs.pop("aggregate", False)
self.conv = nn.Conv2d(*args, **kwargs)
def __call__(self, x):
x = self.conv(x)
if self.aggregate:
x = mx.distributed.all_sum(x)
return x
def sharding(path, weight):
parts = path.split(".")
even = int(parts[1]) % 2 == 0
if even:
return 0
else:
return -1 if parts[-1] != "bias" else None
mod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3),
)
smod = nn.Sequential(
MyConv(3, 128, kernel_size=3),
MyConv(128, 128, kernel_size=3, aggregate=True),
MyConv(128, 128, kernel_size=3),
MyConv(128, 3, kernel_size=3, aggregate=True),
)
smod.update(mod.parameters())
shard_inplace(smod, sharding)
x = mx.random.normal((4, 16, 16, 3))
y1 = mod(x)
y2 = smod(x)
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()