Fix the test and add custom min/max reductions for uncommon MPI types (#2060)

This commit is contained in:
Angelos Katharopoulos
2025-04-10 17:01:17 -07:00
committed by GitHub
parent dfae2c6989
commit ddaa4b7dcb
4 changed files with 135 additions and 47 deletions

View File

@@ -51,16 +51,25 @@ class TestRingDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
x = (
mx.random.uniform(shape=(world.size(),) + sh, key=key) * 10
).astype(dt)
for reduction in reductions:
reducer_distributed = getattr(mx.distributed, f"all_{reduction}")
y = reducer_distributed(x[world.rank()])
reducer = getattr(mx, reduction)
z = reducer(x, axis=0)
mx.eval(y, z)
# All sum
y = mx.distributed.all_sum(x[world.rank()])
z = x.sum(0)
maxrelerror = (y - z).abs()
if rtol > 0:
maxrelerror /= z.abs()
maxrelerror = maxrelerror.max()
self.assertLessEqual(maxrelerror, rtol)
maxrelerror = ((y - z).abs() / z.abs()).max()
self.assertLessEqual(maxrelerror, rtol)
# All max
y = mx.distributed.all_max(x[world.rank()])
z = x.max(0)
self.assertTrue(mx.all(y == z))
# All min
y = mx.distributed.all_min(x[world.rank()])
z = x.min(0)
self.assertTrue(mx.all(y == z))
def test_all_gather(self):
world = mx.distributed.init()