Mirror of https://github.com/ml-explore/mlx.git
Nccl reduce scatter, all gather (#2727)
* Added reduce scatter and all gather for nccl
* fix unused import, delete unused file
* small fix
* deleted useless condition
* fixed comments
* fix bug in eval_gpu, renamed to sum_scatter, fix docs
* final fix docs
* remove and
* Update mlx/distributed/mpi/mpi.cpp

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* fix broken set input output
* fixes set output
* typo
* fix typo
* no cpu, no gpu for reduce scatter

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
commit 27778156dc
parent 761f901a41
committed by GitHub
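
The commit message above names the new NCCL collectives, while the diff below only touches the shared Python tests. For orientation, a minimal usage sketch of the gather and reduce-scatter paths follows; the sum_scatter name is taken from the commit message bullet "renamed to sum_scatter", so its exact signature and semantics are an assumption here, not documented API.

# Hypothetical usage sketch, not part of this commit's diff. Assumes the
# reduce-scatter op is exposed as mx.distributed.sum_scatter and that
# mx.distributed.init() selects an available backend (e.g. NCCL on CUDA builds).
import mlx.core as mx

group = mx.distributed.init()
x = mx.ones((group.size() * 2, 4))

# all_gather concatenates every rank's x along axis 0.
gathered = mx.distributed.all_gather(x, group=group)  # shape: (size * size * 2, 4)

# Assumed sum_scatter semantics: sum across ranks, then split along axis 0,
# so each rank keeps a (2, 4) slice whose entries equal group.size().
scattered = mx.distributed.sum_scatter(x, group=group)

mx.eval(gathered, scattered)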
@@ -1,7 +1,5 @@
 # Copyright © 2025 Apple Inc.

-import unittest
-
 import mlx.core as mx
 import mlx.nn as nn
 import mlx_tests
@@ -63,6 +61,48 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         finally:
             mx.distributed.all_sum = original_all_sum

+    def test_all_reduce(self):
+        g = mx.distributed.init()
+        dtypes = [
+            (mx.int8, 0),
+            (mx.uint8, 0),
+            (mx.int32, 0),
+            (mx.uint32, 0),
+            (mx.float32, 1e-6),
+            (mx.float16, 5e-3),
+            (mx.bfloat16, 1e-1),
+        ]
+        sizes = [
+            (7,),
+            (10,),
+            (1024,),
+            (1024, 1024),
+        ]
+        key = mx.random.key(0)
+
+        for dt, rtol in dtypes:
+            for sh in sizes:
+                x = (mx.random.uniform(shape=(g.size(),) + sh, key=key) * 10).astype(dt)
+
+                # All sum
+                y = mx.distributed.all_sum(x[g.rank()], group=g)
+                z = x.sum(0)
+                maxrelerror = (y - z).abs()
+                if rtol > 0:
+                    maxrelerror /= z.abs()
+                maxrelerror = maxrelerror.max()
+                self.assertLessEqual(maxrelerror, rtol)
+
+                # All max
+                y = mx.distributed.all_max(x[g.rank()], group=g)
+                z = x.max(0)
+                self.assertTrue(mx.all(y == z))
+
+                # All min
+                y = mx.distributed.all_min(x[g.rank()], group=g)
+                z = x.min(0)
+                self.assertTrue(mx.all(y == z))
+
     def test_donation(self):
         x = mx.random.normal((1024,))
         mx.eval(x)
@@ -103,18 +143,19 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         y = lin(x)
         y1 = slin1(x)
         y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1))
+        self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+        self.assertTrue(mx.allclose(y[part], y1, atol=self.atol, rtol=self.rtol))

-        # And their quant versions
-        qlin = lin.to_quantized()
-        slin1 = shard_linear(qlin, "all-to-sharded")
-        slin2 = shard_linear(qlin, "sharded-to-all")
-        y = qlin(x)
-        y1 = slin1(x)
-        y2 = slin2(x[part])
-        self.assertTrue(mx.allclose(y, y2, atol=1e-6, rtol=1e-4))
-        self.assertTrue(mx.allclose(y[part], y1))
+        # And their quant versions (QuantizedMatmul is not supported on CUDA)
+        if not mx.cuda.is_available():
+            qlin = lin.to_quantized()
+            slin1 = shard_linear(qlin, "all-to-sharded")
+            slin2 = shard_linear(qlin, "sharded-to-all")
+            y = qlin(x)
+            y1 = slin1(x)
+            y2 = slin2(x[part])
+            self.assertTrue(mx.allclose(y, y2, atol=self.atol, rtol=self.rtol))
+            self.assertTrue(mx.allclose(y[part], y1))

         # Check the backward works as expected
         def dummy_loss(model, x, y):
@@ -197,12 +238,18 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         )
         self.assertTrue(
             mx.allclose(
-                g1["layers"][1]["bias"], g2["layers"][1]["bias"], atol=1e-6, rtol=1e-4
+                g1["layers"][1]["bias"],
+                g2["layers"][1]["bias"],
+                atol=self.atol,
+                rtol=self.rtol,
             )
         )
         self.assertTrue(
             mx.allclose(
-                g1["layers"][3]["bias"], g2["layers"][3]["bias"], atol=1e-6, rtol=1e-4
+                g1["layers"][3]["bias"],
+                g2["layers"][3]["bias"],
+                atol=self.atol,
+                rtol=self.rtol,
             )
         )

@@ -248,3 +295,20 @@ class MLXDistributedCommonTestCase(mlx_tests.MLXTestCase):
         y1 = mod(x)
         y2 = smod(x)
         self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4))
+
+    def test_all_gather(self):
+        world = mx.distributed.init()
+        dtypes = [
+            mx.int8,
+            mx.uint8,
+            mx.int32,
+            mx.uint32,
+            mx.float32,
+            mx.float16,
+            mx.bfloat16,
+        ]
+        for dt in dtypes:
+            x = mx.ones((2, 2, 4), dtype=dt)
+            y = mx.distributed.all_gather(x)
+            self.assertEqual(y.shape, (world.size() * 2, 2, 4))
+            self.assertTrue(mx.all(y == 1))
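
The hunks above add test_all_reduce and test_all_gather, but no test for the reduce-scatter path named in the commit title appears in this excerpt (the message also notes "no cpu, no gpu for reduce scatter"). A hypothetical companion test in the same style, assuming an mx.distributed.sum_scatter op with conventional reduce-scatter semantics (sum across ranks, result split along the first axis), could look like this:

    # Hypothetical sketch only, not part of the diff shown above.
    def test_sum_scatter(self):
        world = mx.distributed.init()
        x = mx.ones((world.size() * 2, 2, 4))
        # Assumed semantics: values are summed across ranks and scattered along
        # axis 0, so each rank receives a (2, 2, 4) slice equal to world.size().
        y = mx.distributed.sum_scatter(x, group=world)
        self.assertEqual(y.shape, (2, 2, 4))
        self.assertTrue(mx.all(y == world.size()))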