mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-06 16:51:24 +08:00
NCCL backend (#2476)
This commit is contained in:

committed by
GitHub

parent
e843c4d8d5
commit
9392fc3f88
284
python/tests/nccl_test_distributed.py
Normal file
284
python/tests/nccl_test_distributed.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx_tests
|
||||
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
||||
from mlx.nn.utils import average_gradients
|
||||
|
||||
|
||||
class TestNCCLDistributed(mlx_tests.MLXTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
world = mx.distributed.init(strict=True, backend="nccl")
|
||||
rank = world.rank()
|
||||
mx.set_default_device(mx.Device(mx.gpu, rank % 8))
|
||||
|
||||
def test_all_reduce(self):
|
||||
world = mx.distributed.init()
|
||||
dtypes = [
|
||||
(mx.int8, 0),
|
||||
(mx.uint8, 0),
|
||||
(mx.int32, 0),
|
||||
(mx.uint32, 0),
|
||||
(mx.float32, 1e-6),
|
||||
(mx.float16, 5e-3),
|
||||
(mx.bfloat16, 1e-1),
|
||||
]
|
||||
sizes = [
|
||||
(7,),
|
||||
(10,),
|
||||
(1024,),
|
||||
(1024, 1024),
|
||||
]
|
||||
key = mx.random.key(0)
|
||||
|
||||
for dt, rtol in dtypes:
|
||||
for sh in sizes:
|
||||
x = (
|
||||
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
|
||||
).astype(dt)
|
||||
|
||||
# All sum
|
||||
y = mx.distributed.all_sum(x[world.rank()])
|
||||
z = x.sum(0)
|
||||
maxrelerror = (y - z).abs()
|
||||
if rtol > 0:
|
||||
maxrelerror /= z.abs()
|
||||
maxrelerror = maxrelerror.max()
|
||||
self.assertLessEqual(maxrelerror, rtol)
|
||||
|
||||
def test_average_gradients(self):
|
||||
original_all_sum = mx.distributed.all_sum
|
||||
n_calls = 0
|
||||
xtype = None
|
||||
|
||||
def new_all_sum(x, **kwargs):
|
||||
nonlocal n_calls
|
||||
nonlocal xtype
|
||||
|
||||
n_calls += 1
|
||||
if xtype is not None:
|
||||
self.assertEqual(xtype, x.dtype)
|
||||
|
||||
return original_all_sum(x, **kwargs)
|
||||
|
||||
mx.distributed.all_sum = new_all_sum
|
||||
try:
|
||||
grads = [mx.ones(10) for i in range(10)]
|
||||
new_grads = average_gradients(grads, stream=mx.gpu)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 1)
|
||||
|
||||
n_calls = 0
|
||||
new_grads = average_gradients(grads, all_reduce_size=4 * 50, stream=mx.gpu)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 2)
|
||||
|
||||
n_calls = 0
|
||||
new_grads = average_gradients(grads, all_reduce_size=0, stream=mx.gpu)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 10)
|
||||
|
||||
n_calls = 0
|
||||
xtype = mx.float16
|
||||
new_grads = average_gradients(
|
||||
grads,
|
||||
all_reduce_size=2 * 50,
|
||||
communication_type=mx.float16,
|
||||
stream=mx.gpu,
|
||||
)
|
||||
mx.eval(new_grads)
|
||||
self.assertEqual(len(new_grads), 10)
|
||||
self.assertTrue(all(g.dtype == mx.float32 for g in new_grads))
|
||||
self.assertTrue(all(mx.all(g == 1) for g in new_grads))
|
||||
self.assertEqual(n_calls, 2)
|
||||
|
||||
finally:
|
||||
mx.distributed.all_sum = original_all_sum
|
||||
|
||||
def test_donation(self):
|
||||
x = mx.random.normal((1024,))
|
||||
mx.eval(x)
|
||||
mx.synchronize()
|
||||
|
||||
mx.reset_peak_memory()
|
||||
scale = mx.array(2.0)
|
||||
y = mx.distributed.all_sum(x)
|
||||
mx.eval(y)
|
||||
mx.synchronize()
|
||||
all_sum_only = mx.get_peak_memory()
|
||||
y = mx.distributed.all_sum(x) * scale
|
||||
mx.eval(y)
|
||||
mx.synchronize()
|
||||
all_sum_with_binary = mx.get_peak_memory()
|
||||
|
||||
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||
|
||||
def test_shard_linear(self):
|
||||
# Seed the prng to have the same inputs and weights generated everywhere
|
||||
mx.random.seed(0xF0F0F0F0)
|
||||
|
||||
# Prepare inputs
|
||||
world = mx.distributed.init()
|
||||
part = (
|
||||
slice(None),
|
||||
slice(
|
||||
world.rank() * 1024 // world.size(),
|
||||
(world.rank() + 1) * 1024 // world.size(),
|
||||
),
|
||||
)
|
||||
x = mx.random.normal((4, 1024))
|
||||
|
||||
# Create and shard some linear layers
|
||||
lin = nn.Linear(1024, 1024, bias=True)
|
||||
slin1 = shard_linear(lin, "all-to-sharded")
|
||||
slin2 = shard_linear(lin, "sharded-to-all")
|
||||
y = lin(x)
|
||||
y1 = slin1(x)
|
||||
y2 = slin2(x[part])
|
||||
self.assertTrue(mx.allclose(y, y2, atol=1e-4, rtol=1e-4))
|
||||
self.assertTrue(mx.allclose(y[part], y1, atol=1e-4, rtol=1e-4))
|
||||
|
||||
# Check the backward works as expected
|
||||
def dummy_loss(model, x, y):
|
||||
return (model(x) * y).sum()
|
||||
|
||||
mod = nn.Sequential(
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
nn.Linear(128, 128),
|
||||
)
|
||||
smod = nn.Sequential(
|
||||
shard_linear(mod.layers[0], "all-to-sharded"),
|
||||
shard_linear(mod.layers[1], "sharded-to-all"),
|
||||
shard_linear(mod.layers[2], "all-to-sharded"),
|
||||
shard_linear(mod.layers[3], "sharded-to-all"),
|
||||
)
|
||||
|
||||
grad1 = nn.value_and_grad(mod, dummy_loss)
|
||||
grad2 = nn.value_and_grad(smod, dummy_loss)
|
||||
|
||||
x = mx.random.normal((4, 128))
|
||||
y = mx.random.normal((4, 128))
|
||||
|
||||
l1, g1 = grad1(mod, x, y)
|
||||
l2, g2 = grad2(smod, x, y)
|
||||
mx.eval(l1, g1, l2, g2)
|
||||
|
||||
part = slice(
|
||||
world.rank() * 128 // world.size(), (world.rank() + 1) * 128 // world.size()
|
||||
)
|
||||
|
||||
self.assertTrue(mx.allclose(l1, l2))
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][0]["weight"][part],
|
||||
g2["layers"][0]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][2]["weight"][part],
|
||||
g2["layers"][2]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][1]["weight"][:, part],
|
||||
g2["layers"][1]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][3]["weight"][:, part],
|
||||
g2["layers"][3]["weight"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][0]["bias"][part],
|
||||
g2["layers"][0]["bias"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][2]["bias"][part],
|
||||
g2["layers"][2]["bias"],
|
||||
atol=1e-6,
|
||||
rtol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
mx.allclose(
|
||||
g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
def test_shard_predicate(self):
|
||||
mx.random.seed(0xF0F0F0F0)
|
||||
|
||||
class MyConv(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.aggregate = kwargs.pop("aggregate", False)
|
||||
self.conv = nn.Conv2d(*args, **kwargs)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv(x)
|
||||
if self.aggregate:
|
||||
x = mx.distributed.all_sum(x)
|
||||
return x
|
||||
|
||||
def sharding(path, weight):
|
||||
parts = path.split(".")
|
||||
even = int(parts[1]) % 2 == 0
|
||||
if even:
|
||||
return 0
|
||||
else:
|
||||
return -1 if parts[-1] != "bias" else None
|
||||
|
||||
mod = nn.Sequential(
|
||||
MyConv(3, 128, kernel_size=3),
|
||||
MyConv(128, 128, kernel_size=3),
|
||||
MyConv(128, 128, kernel_size=3),
|
||||
MyConv(128, 3, kernel_size=3),
|
||||
)
|
||||
smod = nn.Sequential(
|
||||
MyConv(3, 128, kernel_size=3),
|
||||
MyConv(128, 128, kernel_size=3, aggregate=True),
|
||||
MyConv(128, 128, kernel_size=3),
|
||||
MyConv(128, 3, kernel_size=3, aggregate=True),
|
||||
)
|
||||
smod.update(mod.parameters())
|
||||
shard_inplace(smod, sharding)
|
||||
|
||||
x = mx.random.normal((4, 16, 16, 3))
|
||||
y1 = mod(x)
|
||||
y2 = smod(x)
|
||||
self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
Reference in New Issue
Block a user