Min / max reductions (#2041)

This commit is contained in:
Anastasiia Filippova
2025-04-10 08:22:20 +02:00
committed by GitHub
parent 9ecefd56db
commit 515f104926
11 changed files with 276 additions and 27 deletions

View File

@@ -44,17 +44,23 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
(1024, 1024),
]
key = mx.random.key(0)
reductions = ["min", "max", "sum"]
for dt, rtol in dtypes:
for sh in sizes:
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
y = mx.distributed.all_sum(x[world.rank()])
z = sum(
x[i] for i in range(world.size())
) # to ensure that we don't sum to int32
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
for reduction in reductions:
reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
y = reducer_distributed(x[world.rank()])
reducer = getattr(mx, reduction)
z = reducer(x, axis=0)
mx.eval(y, z)
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
def test_all_gather(self):
world = mx.distributed.init()