Files
mlx/python/tests/nccl_test_distributed.py

53 lines
1.4 KiB
Python
Raw Normal View History

2025-08-21 20:56:15 +02:00
# Copyright © 2024 Apple Inc.
2025-08-21 20:56:15 +02:00
import mlx.core as mx
import mlx_distributed_tests
2025-08-21 20:56:15 +02:00
import mlx_tests
class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
    """Distributed test case that initializes the NCCL backend.

    Extends ``MLXDistributedCommonTestCase`` with a ``sum_scatter`` check.
    """

    @classmethod
    def setUpClass(cls):
        """Initialize the NCCL backend and set the class-level tolerances."""
        # The returned group is not used here; only the initialization matters.
        mx.distributed.init(strict=True, backend="nccl")
        cls.atol = cls.rtol = 1e-4

    def test_sum_scatter(self):
        """Compare ``sum_scatter`` with this rank's slice of ``all_sum``.

        For every dtype/shape combination the same input is reduced with both
        collectives. The ``sum_scatter`` result is compared against the
        first-axis slice ``[rank * chunk, (rank + 1) * chunk)`` of the
        ``all_sum`` result, and the largest relative error must not exceed the
        tolerance listed for that dtype.
        """
        group = mx.distributed.init()
        rank = group.rank()
        nranks = group.size()

        # (dtype, maximum allowed relative error) pairs. These per-dtype
        # tolerances are local to this test and independent of ``cls.rtol``.
        tolerances = [
            (mx.float32, 1e-6),
            (mx.float16, 5e-3),
            (mx.bfloat16, 1e-1),
        ]
        shapes = [
            (8,),
            (64,),
            (1024,),
            (1024, 1024),
        ]
        # The random key is derived from the rank and reused for every case.
        rng_key = mx.random.key(rank)

        # Same iteration order as nested loops: dtype outermost, shape innermost.
        cases = [
            (dtype, tol, shape) for dtype, tol in tolerances for shape in shapes
        ]
        for dtype, tol, shape in cases:
            data = (mx.random.uniform(shape=shape, key=rng_key) * 10).astype(dtype)

            scattered = mx.distributed.sum_scatter(data)
            reduced = mx.distributed.all_sum(data)

            # Reference: the block of the full reduction belonging to this
            # rank, taken along the first axis.
            rows = shape[0] // nranks
            first = rank * rows
            expected = reduced[first : first + rows]

            err = (scattered - expected).abs()
            if tol > 0:
                # Turn the absolute error into a relative one.
                err /= expected.abs()
            self.assertLessEqual(err.max(), tol)
if __name__ == "__main__":
    # Executed as a script: hand control to the shared MLX test runner.
    mlx_tests.MLXTestRunner()